whispercpp 1.3.1 → 1.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +4 -3
- data/README.md +92 -31
- data/Rakefile +26 -7
- data/ext/.gitignore +5 -7
- data/ext/dependencies.rb +61 -0
- data/ext/extconf.rb +21 -198
- data/ext/options.rb +221 -0
- data/ext/ruby_whisper.c +159 -0
- data/ext/ruby_whisper.h +17 -2
- data/ext/ruby_whisper_context.c +641 -0
- data/ext/ruby_whisper_error.c +52 -0
- data/ext/ruby_whisper_model.c +232 -0
- data/ext/ruby_whisper_params.c +1301 -0
- data/ext/ruby_whisper_segment.c +143 -0
- data/ext/ruby_whisper_transcribe.cpp +87 -0
- data/ext/ruby_whisper_vad_params.c +288 -0
- data/ext/sources/.dockerignore +3 -0
- data/ext/sources/.github/workflows/bindings-ruby.yml +21 -0
- data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
- data/ext/sources/CMakeLists.txt +251 -0
- data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
- data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
- data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
- data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
- data/ext/sources/bindings/javascript/package.json +26 -0
- data/ext/sources/bindings/javascript/whisper.js +19 -0
- data/ext/sources/build-xcframework.sh +547 -0
- data/ext/sources/ci/run.sh +336 -0
- data/ext/sources/close-issue.yml +28 -0
- data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
- data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
- data/ext/sources/cmake/build-info.cmake +60 -0
- data/ext/sources/cmake/git-vars.cmake +22 -0
- data/ext/sources/cmake/whisper-config.cmake.in +65 -0
- data/ext/sources/cmake/whisper.pc.in +10 -0
- data/ext/sources/examples/CMakeLists.txt +124 -0
- data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +37 -0
- data/ext/sources/examples/addon.node/addon.cpp +438 -0
- data/ext/sources/examples/addon.node/index.js +54 -0
- data/ext/sources/examples/addon.node/package.json +16 -0
- data/ext/sources/examples/bench/CMakeLists.txt +8 -0
- data/ext/sources/examples/bench/bench.cpp +175 -0
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
- data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
- data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
- data/ext/sources/examples/cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/cli/cli.cpp +1294 -0
- data/ext/sources/examples/coi-serviceworker.js +146 -0
- data/ext/sources/examples/command/CMakeLists.txt +10 -0
- data/ext/sources/examples/command/command.cpp +776 -0
- data/ext/sources/examples/command/commands.txt +9 -0
- data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
- data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
- data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
- data/ext/sources/examples/common-ggml.cpp +238 -0
- data/ext/sources/examples/common-ggml.h +18 -0
- data/ext/sources/examples/common-sdl.cpp +227 -0
- data/ext/sources/examples/common-sdl.h +49 -0
- data/ext/sources/examples/common-whisper.cpp +168 -0
- data/ext/sources/examples/common-whisper.h +24 -0
- data/ext/sources/examples/common.cpp +675 -0
- data/ext/sources/examples/common.h +322 -0
- data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
- data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
- data/ext/sources/examples/generate-karaoke.sh +57 -0
- data/ext/sources/examples/grammar-parser.cpp +423 -0
- data/ext/sources/examples/grammar-parser.h +29 -0
- data/ext/sources/examples/helpers.js +191 -0
- data/ext/sources/examples/json.hpp +24596 -0
- data/ext/sources/examples/livestream.sh +112 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
- data/ext/sources/examples/lsp/lsp.cpp +467 -0
- data/ext/sources/examples/lsp/whisper.vim +362 -0
- data/ext/sources/examples/miniaudio.h +93468 -0
- data/ext/sources/examples/python/test_whisper_processor.py +7 -0
- data/ext/sources/examples/python/whisper_processor.py +54 -0
- data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
- data/ext/sources/examples/quantize/quantize.cpp +223 -0
- data/ext/sources/examples/server/CMakeLists.txt +12 -0
- data/ext/sources/examples/server/bench.js +29 -0
- data/ext/sources/examples/server/httplib.h +10497 -0
- data/ext/sources/examples/server/server.cpp +1091 -0
- data/ext/sources/examples/server.py +115 -0
- data/ext/sources/examples/stb_vorbis.c +5584 -0
- data/ext/sources/examples/stream/CMakeLists.txt +10 -0
- data/ext/sources/examples/stream/stream.cpp +429 -0
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
- data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
- data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
- data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
- data/ext/sources/examples/sycl/build.sh +22 -0
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
- data/ext/sources/examples/sycl/run-whisper.sh +17 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +40 -0
- data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
- data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +1746 -0
- data/ext/sources/examples/talk-llama/llama-arch.h +437 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +374 -0
- data/ext/sources/examples/talk-llama/llama-batch.h +89 -0
- data/ext/sources/examples/talk-llama/llama-chat.cpp +663 -0
- data/ext/sources/examples/talk-llama/llama-chat.h +58 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +2676 -0
- data/ext/sources/examples/talk-llama/llama-context.h +276 -0
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
- data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
- data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
- data/ext/sources/examples/talk-llama/llama-graph.cpp +1618 -0
- data/ext/sources/examples/talk-llama/llama-graph.h +640 -0
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +95 -0
- data/ext/sources/examples/talk-llama/llama-hparams.h +190 -0
- data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
- data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
- data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
- data/ext/sources/examples/talk-llama/llama-io.h +35 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2739 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +502 -0
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +379 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +32 -0
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
- data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1138 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +281 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +13814 -0
- data/ext/sources/examples/talk-llama/llama-model.h +425 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +966 -0
- data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
- data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +3340 -0
- data/ext/sources/examples/talk-llama/llama-vocab.h +131 -0
- data/ext/sources/examples/talk-llama/llama.cpp +354 -0
- data/ext/sources/examples/talk-llama/llama.h +1377 -0
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
- data/ext/sources/examples/talk-llama/speak +40 -0
- data/ext/sources/examples/talk-llama/speak.bat +1 -0
- data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
- data/ext/sources/examples/talk-llama/talk-llama.cpp +808 -0
- data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
- data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +849 -0
- data/ext/sources/examples/talk-llama/unicode.h +66 -0
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +143 -0
- data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
- data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +249 -0
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
- data/ext/sources/ggml/CMakeLists.txt +390 -0
- data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
- data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
- data/ext/sources/ggml/cmake/common.cmake +26 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
- data/ext/{ggml → sources/ggml}/include/ggml-alloc.h +1 -1
- data/ext/{ggml → sources/ggml}/include/ggml-backend.h +9 -7
- data/ext/{ggml → sources/ggml}/include/ggml-cpp.h +2 -1
- data/ext/{ggml → sources/ggml}/include/ggml-cpu.h +9 -1
- data/ext/{ggml → sources/ggml}/include/ggml-metal.h +1 -1
- data/ext/{ggml → sources/ggml}/include/ggml-opt.h +49 -28
- data/ext/{ggml → sources/ggml}/include/ggml-rpc.h +6 -1
- data/ext/{ggml → sources/ggml}/include/ggml-vulkan.h +0 -2
- data/ext/{ggml → sources/ggml}/include/ggml.h +182 -265
- data/ext/sources/ggml/include/gguf.h +202 -0
- data/ext/sources/ggml/src/CMakeLists.txt +346 -0
- data/ext/{ggml → sources/ggml}/src/ggml-alloc.c +34 -29
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- data/ext/{ggml → sources/ggml}/src/ggml-backend-impl.h +1 -2
- data/ext/{ggml → sources/ggml}/src/ggml-backend-reg.cpp +87 -53
- data/ext/{ggml → sources/ggml}/src/ggml-backend.cpp +26 -14
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +74 -0
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.cpp +10 -4
- data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.h +5 -5
- data/ext/{ggml → sources/ggml}/src/ggml-cann/aclnn_ops.cpp +1272 -1506
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cann/common.h +135 -1
- data/ext/{ggml → sources/ggml}/src/ggml-cann/ggml-cann.cpp +564 -146
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/dup.cpp +3 -5
- data/ext/{ggml → sources/ggml}/src/ggml-common.h +12 -8
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +504 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.cpp +2 -1
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +72 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/cpu-feats-x86.cpp +5 -1
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +6431 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-impl.h +163 -41
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.c +4029 -1117
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3510 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu.cpp +67 -18
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3544 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +8903 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.h +110 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +252 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +818 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
- data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
- data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
- data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
- data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +828 -0
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
- data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +730 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +26 -0
- data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
- data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
- data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
- data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1471 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +634 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3505 -0
- data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
- data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +336 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
- data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +153 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
- data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +289 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +59 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/cuda.h +1 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/hip.h +57 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/musa.h +7 -1
- data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
- data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +131 -0
- data/ext/{ggml → sources/ggml}/src/ggml-impl.h +64 -19
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +120 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +622 -0
- data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.m +2178 -1064
- data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.metal +1575 -1218
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
- data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +96 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +5124 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
- data/ext/{ggml → sources/ggml}/src/ggml-opt.cpp +373 -190
- data/ext/{ggml → sources/ggml}/src/ggml-quants.c +114 -120
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
- data/ext/{ggml → sources/ggml}/src/ggml-rpc/ggml-rpc.cpp +480 -73
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +345 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/common.cpp +20 -32
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +589 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/concat.cpp +32 -33
- data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/conv.cpp +4 -2
- data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/convert.cpp +104 -28
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +700 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +791 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/dmmv.cpp +156 -17
- data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2957 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1511 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +75 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +99 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/ggml-sycl.cpp +1004 -1240
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
- data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmq.cpp +0 -1
- data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmvq.cpp +261 -166
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/norm.cpp +204 -81
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/outprod.cpp +8 -17
- data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +83 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +361 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/softmax.cpp +35 -25
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/tsembd.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1215 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +293 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +196 -0
- data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
- data/ext/{ggml → sources/ggml}/src/ggml-vulkan/ggml-vulkan.cpp +3130 -1087
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +52 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
- data/ext/{ggml → sources/ggml}/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
- data/ext/{ggml → sources/ggml}/src/ggml.c +676 -1820
- data/ext/sources/ggml/src/gguf.cpp +1330 -0
- data/ext/{include → sources/include}/whisper.h +68 -2
- data/ext/sources/src/CMakeLists.txt +143 -0
- data/ext/{src → sources/src}/coreml/whisper-decoder-impl.h +27 -15
- data/ext/{src → sources/src}/coreml/whisper-decoder-impl.m +35 -10
- data/ext/{src → sources/src}/coreml/whisper-encoder-impl.h +21 -9
- data/ext/{src → sources/src}/coreml/whisper-encoder-impl.m +28 -3
- data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
- data/ext/sources/src/whisper-arch.h +197 -0
- data/ext/{src → sources/src}/whisper.cpp +1905 -374
- data/ext/sources/tests/CMakeLists.txt +105 -0
- data/ext/sources/tests/earnings21/eval.mk +58 -0
- data/ext/sources/tests/earnings21/eval.py +68 -0
- data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
- data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
- data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
- data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
- data/ext/sources/tests/earnings21/requirements.txt +6 -0
- data/ext/sources/tests/en-0-ref.txt +1 -0
- data/ext/sources/tests/en-1-ref.txt +1 -0
- data/ext/sources/tests/en-2-ref.txt +1 -0
- data/ext/sources/tests/es-0-ref.txt +1 -0
- data/ext/sources/tests/librispeech/eval.mk +39 -0
- data/ext/sources/tests/librispeech/eval.py +47 -0
- data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
- data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
- data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
- data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
- data/ext/sources/tests/librispeech/requirements.txt +6 -0
- data/ext/sources/tests/run-tests.sh +130 -0
- data/ext/sources/tests/test-c.c +3 -0
- data/ext/sources/tests/test-vad-full.cpp +54 -0
- data/ext/sources/tests/test-vad.cpp +83 -0
- data/ext/sources/tests/test-whisper.js +58 -0
- data/extsources.rb +33 -5
- data/lib/whisper/model/uri.rb +149 -128
- data/sig/whisper.rbs +480 -0
- data/tests/helper.rb +28 -0
- data/tests/test_callback.rb +45 -3
- data/tests/test_error.rb +2 -2
- data/tests/test_model.rb +38 -0
- data/tests/test_package.rb +18 -3
- data/tests/test_params.rb +145 -8
- data/tests/test_segment.rb +10 -19
- data/tests/test_vad.rb +19 -0
- data/tests/test_vad_params.rb +103 -0
- data/tests/test_whisper.rb +37 -37
- data/whispercpp.gemspec +5 -4
- metadata +766 -111
- data/ext/cpu.mk +0 -9
- data/ext/examples/dr_wav.h +0 -8815
- data/ext/ggml/src/ggml-cann/aclnn_ops.h +0 -592
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -4262
- data/ext/ggml/src/ggml-cpu/ggml-cpu.c +0 -14123
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +0 -1884
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +0 -14
- data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +0 -288
- data/ext/ggml/src/ggml-sycl/element_wise.cpp +0 -1030
- data/ext/ggml/src/ggml-sycl/im2col.cpp +0 -126
- data/ext/ggml/src/ggml-sycl/rope.cpp +0 -276
- data/ext/ggml/src/ggml-sycl/wkv6.cpp +0 -141
- data/ext/metal-embed.mk +0 -17
- data/ext/metal.mk +0 -6
- data/ext/ruby_whisper.cpp +0 -1909
- data/ext/scripts/get-flags.mk +0 -38
- data/lib/whisper.rb +0 -2
- /data/ext/{ggml → sources/ggml}/include/ggml-blas.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-cann.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-cuda.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-kompute.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-opencl.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-sycl.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/common.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/ggml-amx.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-blas/ggml-blas.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/ascendc_kernels.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f16.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f32.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/common.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-aarch64.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-kompute/ggml-kompute.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-quants.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-threading.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-threading.h +0 -0
- /data/ext/{src → sources/src}/coreml/whisper-encoder.h +0 -0
- /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.cpp +0 -0
- /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.h +0 -0
@@ -3,8 +3,7 @@
|
|
3
3
|
#if defined(GGML_METAL_EMBED_LIBRARY)
|
4
4
|
__embed_ggml-common.h__
|
5
5
|
#else
|
6
|
-
|
7
|
-
#include "../ggml-common.h"
|
6
|
+
#include "ggml-common.h"
|
8
7
|
#endif
|
9
8
|
#include "ggml-metal-impl.h"
|
10
9
|
|
@@ -49,7 +48,7 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
|
|
49
48
|
|
50
49
|
template <typename type4>
|
51
50
|
void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
|
52
|
-
reg = (type4)(*(src
|
51
|
+
reg = (type4)(*(src));
|
53
52
|
}
|
54
53
|
|
55
54
|
#if defined(GGML_METAL_USE_BF16)
|
@@ -57,6 +56,11 @@ template <typename type4x4>
|
|
57
56
|
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
|
58
57
|
reg = (type4x4)(*src);
|
59
58
|
}
|
59
|
+
|
60
|
+
template <typename type4>
|
61
|
+
void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
|
62
|
+
reg = (type4)(*(src));
|
63
|
+
}
|
60
64
|
#endif
|
61
65
|
|
62
66
|
template <typename type4x4>
|
@@ -373,24 +377,33 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
373
377
|
template <typename type4x4>
|
374
378
|
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
|
375
379
|
const half d_all = xb->d;
|
376
|
-
device const
|
377
|
-
device const
|
380
|
+
device const uint16_t * ql = (device const uint16_t *)xb->ql;
|
381
|
+
device const uint16_t * qh = (device const uint16_t *)xb->qh;
|
378
382
|
device const int8_t * scales = (device const int8_t *)xb->scales;
|
379
383
|
|
380
|
-
ql = ql +
|
381
|
-
qh = qh +
|
384
|
+
ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
|
385
|
+
qh = qh + 16*(il/8) + 8*(il&1);
|
382
386
|
float sc = scales[(il%2) + 2 * ((il/2))];
|
383
387
|
il = (il/2) & 3;
|
384
388
|
|
385
|
-
const
|
386
|
-
const
|
387
|
-
const float coef = il>1 ? 1.f/16.f : 1.f;
|
389
|
+
const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
|
390
|
+
const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F;
|
388
391
|
const float ml = d_all * sc * 32.f;
|
389
|
-
const float
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
392
|
+
const float dl0 = d_all * sc;
|
393
|
+
const float dl1 = dl0 / 256.f;
|
394
|
+
const float dl2 = dl0 / (256.f * 256.f);
|
395
|
+
const float dl3 = dl0 / (256.f * 256.f * 256.f);
|
396
|
+
const uint8_t shr_h = il>2 ? 2 : 0;
|
397
|
+
const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
|
398
|
+
const uint8_t shr_l = il>1 ? 4 : 0;
|
399
|
+
for (int i = 0; i < 4; ++i) {
|
400
|
+
const uint32_t low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
|
401
|
+
const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
|
402
|
+
const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
|
403
|
+
reg[i][0] = dl0 * ((half)(q & 0xFF)) - ml;
|
404
|
+
reg[i][1] = dl1 * ((float)(q & 0xFF00)) - ml;
|
405
|
+
reg[i][2] = dl2 * ((float)(q & 0xFF0000)) - ml;
|
406
|
+
reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
|
394
407
|
}
|
395
408
|
}
|
396
409
|
|
@@ -843,6 +856,7 @@ kernel void kernel_tanh(
|
|
843
856
|
constant float GELU_COEF_A = 0.044715f;
|
844
857
|
constant float GELU_QUICK_COEF = -1.702f;
|
845
858
|
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
859
|
+
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
|
846
860
|
|
847
861
|
kernel void kernel_gelu(
|
848
862
|
device const float * src0,
|
@@ -884,6 +898,42 @@ kernel void kernel_gelu_quick_4(
|
|
884
898
|
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
|
885
899
|
}
|
886
900
|
|
901
|
+
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
|
902
|
+
// ref: https://www.johndcook.com/blog/python_erf/
|
903
|
+
constant float p_erf = 0.3275911f;
|
904
|
+
constant float a1_erf = 0.254829592f;
|
905
|
+
constant float a2_erf = -0.284496736f;
|
906
|
+
constant float a3_erf = 1.421413741f;
|
907
|
+
constant float a4_erf = -1.453152027f;
|
908
|
+
constant float a5_erf = 1.061405429f;
|
909
|
+
|
910
|
+
template<typename T>
|
911
|
+
T erf_approx(T x) {
|
912
|
+
T sign_x = sign(x);
|
913
|
+
x = fabs(x);
|
914
|
+
T t = 1.0f / (1.0f + p_erf * x);
|
915
|
+
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
|
916
|
+
return sign_x * y;
|
917
|
+
}
|
918
|
+
|
919
|
+
kernel void kernel_gelu_erf(
|
920
|
+
device const float * src0,
|
921
|
+
device float * dst,
|
922
|
+
uint tpig[[thread_position_in_grid]]) {
|
923
|
+
device const float & x = src0[tpig];
|
924
|
+
|
925
|
+
dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
|
926
|
+
}
|
927
|
+
|
928
|
+
kernel void kernel_gelu_erf_4(
|
929
|
+
device const float4 * src0,
|
930
|
+
device float4 * dst,
|
931
|
+
uint tpig[[thread_position_in_grid]]) {
|
932
|
+
device const float4 & x = src0[tpig];
|
933
|
+
|
934
|
+
dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
|
935
|
+
}
|
936
|
+
|
887
937
|
kernel void kernel_silu(
|
888
938
|
device const float * src0,
|
889
939
|
device float * dst,
|
@@ -936,48 +986,32 @@ kernel void kernel_cos(
|
|
936
986
|
dst[tpig] = cos(src0[tpig]);
|
937
987
|
}
|
938
988
|
|
989
|
+
kernel void kernel_neg(
|
990
|
+
device const float * src0,
|
991
|
+
device float * dst,
|
992
|
+
uint tpig[[thread_position_in_grid]]) {
|
993
|
+
dst[tpig] = -src0[tpig];
|
994
|
+
}
|
995
|
+
|
939
996
|
kernel void kernel_sum_rows(
|
940
997
|
device const float * src0,
|
941
998
|
device float * dst,
|
942
|
-
constant
|
943
|
-
constant int64_t & ne01,
|
944
|
-
constant int64_t & ne02,
|
945
|
-
constant int64_t & ne03,
|
946
|
-
constant uint64_t & nb00,
|
947
|
-
constant uint64_t & nb01,
|
948
|
-
constant uint64_t & nb02,
|
949
|
-
constant uint64_t & nb03,
|
950
|
-
constant int64_t & ne10,
|
951
|
-
constant int64_t & ne11,
|
952
|
-
constant int64_t & ne12,
|
953
|
-
constant int64_t & ne13,
|
954
|
-
constant uint64_t & nb10,
|
955
|
-
constant uint64_t & nb11,
|
956
|
-
constant uint64_t & nb12,
|
957
|
-
constant uint64_t & nb13,
|
958
|
-
constant int64_t & ne0,
|
959
|
-
constant int64_t & ne1,
|
960
|
-
constant int64_t & ne2,
|
961
|
-
constant int64_t & ne3,
|
962
|
-
constant uint64_t & nb0,
|
963
|
-
constant uint64_t & nb1,
|
964
|
-
constant uint64_t & nb2,
|
965
|
-
constant uint64_t & nb3,
|
999
|
+
constant ggml_metal_kargs_sum_rows & args,
|
966
1000
|
uint3 tpig[[thread_position_in_grid]]) {
|
967
1001
|
int64_t i3 = tpig.z;
|
968
1002
|
int64_t i2 = tpig.y;
|
969
1003
|
int64_t i1 = tpig.x;
|
970
1004
|
|
971
|
-
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
|
1005
|
+
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
972
1006
|
return;
|
973
1007
|
}
|
974
1008
|
|
975
|
-
device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
|
976
|
-
device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
|
1009
|
+
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
1010
|
+
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
977
1011
|
|
978
1012
|
float row_sum = 0;
|
979
1013
|
|
980
|
-
for (int64_t i0 = 0; i0 < ne00; i0++) {
|
1014
|
+
for (int64_t i0 = 0; i0 < args.ne00; i0++) {
|
981
1015
|
row_sum += src_row[i0];
|
982
1016
|
}
|
983
1017
|
|
@@ -989,36 +1023,29 @@ kernel void kernel_soft_max(
|
|
989
1023
|
device const char * src0,
|
990
1024
|
device const char * src1,
|
991
1025
|
device char * dst,
|
992
|
-
constant
|
993
|
-
constant int64_t & ne01,
|
994
|
-
constant int64_t & ne02,
|
995
|
-
constant float & scale,
|
996
|
-
constant float & max_bias,
|
997
|
-
constant float & m0,
|
998
|
-
constant float & m1,
|
999
|
-
constant uint32_t & n_head_log2,
|
1026
|
+
constant ggml_metal_kargs_soft_max & args,
|
1000
1027
|
threadgroup float * buf [[threadgroup(0)]],
|
1001
1028
|
uint tgpig[[threadgroup_position_in_grid]],
|
1002
1029
|
uint tpitg[[thread_position_in_threadgroup]],
|
1003
1030
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
1004
1031
|
uint tiisg[[thread_index_in_simdgroup]],
|
1005
1032
|
uint ntg[[threads_per_threadgroup]]) {
|
1006
|
-
const int64_t i03 = (tgpig) / (ne02*ne01);
|
1007
|
-
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
1008
|
-
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
1033
|
+
const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
|
1034
|
+
const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
|
1035
|
+
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
|
1009
1036
|
|
1010
|
-
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
1011
|
-
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
|
1012
|
-
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
1037
|
+
device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
|
1038
|
+
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr;
|
1039
|
+
device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
|
1013
1040
|
|
1014
1041
|
float slope = 1.0f;
|
1015
1042
|
|
1016
1043
|
// ALiBi
|
1017
|
-
if (max_bias > 0.0f) {
|
1044
|
+
if (args.max_bias > 0.0f) {
|
1018
1045
|
const int64_t h = i02;
|
1019
1046
|
|
1020
|
-
const float base = h < n_head_log2 ? m0 : m1;
|
1021
|
-
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
1047
|
+
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
1048
|
+
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
1022
1049
|
|
1023
1050
|
slope = pow(base, exp);
|
1024
1051
|
}
|
@@ -1026,8 +1053,8 @@ kernel void kernel_soft_max(
|
|
1026
1053
|
// parallel max
|
1027
1054
|
float lmax = -INFINITY;
|
1028
1055
|
|
1029
|
-
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
1030
|
-
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
1056
|
+
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
|
1057
|
+
lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
|
1031
1058
|
}
|
1032
1059
|
|
1033
1060
|
// find the max value in the block
|
@@ -1051,14 +1078,14 @@ kernel void kernel_soft_max(
|
|
1051
1078
|
|
1052
1079
|
// parallel sum
|
1053
1080
|
float lsum = 0.0f;
|
1054
|
-
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
1055
|
-
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
|
1081
|
+
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
|
1082
|
+
const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
|
1056
1083
|
lsum += exp_psrc0;
|
1057
1084
|
pdst[i00] = exp_psrc0;
|
1058
1085
|
}
|
1059
1086
|
|
1060
1087
|
// This barrier fixes a failing test
|
1061
|
-
// ref: https://github.com/
|
1088
|
+
// ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
|
1062
1089
|
threadgroup_barrier(mem_flags::mem_none);
|
1063
1090
|
|
1064
1091
|
float sum = simd_sum(lsum);
|
@@ -1082,7 +1109,7 @@ kernel void kernel_soft_max(
|
|
1082
1109
|
|
1083
1110
|
const float inv_sum = 1.0f/sum;
|
1084
1111
|
|
1085
|
-
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
1112
|
+
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
|
1086
1113
|
pdst[i00] *= inv_sum;
|
1087
1114
|
}
|
1088
1115
|
}
|
@@ -1092,35 +1119,28 @@ kernel void kernel_soft_max_4(
|
|
1092
1119
|
device const char * src0,
|
1093
1120
|
device const char * src1,
|
1094
1121
|
device char * dst,
|
1095
|
-
constant
|
1096
|
-
constant int64_t & ne01,
|
1097
|
-
constant int64_t & ne02,
|
1098
|
-
constant float & scale,
|
1099
|
-
constant float & max_bias,
|
1100
|
-
constant float & m0,
|
1101
|
-
constant float & m1,
|
1102
|
-
constant uint32_t & n_head_log2,
|
1122
|
+
constant ggml_metal_kargs_soft_max & args,
|
1103
1123
|
threadgroup float * buf [[threadgroup(0)]],
|
1104
1124
|
uint tgpig[[threadgroup_position_in_grid]],
|
1105
1125
|
uint tpitg[[thread_position_in_threadgroup]],
|
1106
1126
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
1107
1127
|
uint tiisg[[thread_index_in_simdgroup]],
|
1108
1128
|
uint ntg[[threads_per_threadgroup]]) {
|
1109
|
-
const int64_t i03 = (tgpig) / (ne02*ne01);
|
1110
|
-
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
1111
|
-
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
1129
|
+
const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
|
1130
|
+
const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
|
1131
|
+
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
|
1112
1132
|
|
1113
|
-
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
1114
|
-
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
|
1115
|
-
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
1133
|
+
device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
|
1134
|
+
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr;
|
1135
|
+
device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
|
1116
1136
|
|
1117
1137
|
float slope = 1.0f;
|
1118
1138
|
|
1119
|
-
if (max_bias > 0.0f) {
|
1139
|
+
if (args.max_bias > 0.0f) {
|
1120
1140
|
const int64_t h = i02;
|
1121
1141
|
|
1122
|
-
const float base = h < n_head_log2 ? m0 : m1;
|
1123
|
-
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
1142
|
+
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
1143
|
+
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
1124
1144
|
|
1125
1145
|
slope = pow(base, exp);
|
1126
1146
|
}
|
@@ -1128,8 +1148,8 @@ kernel void kernel_soft_max_4(
|
|
1128
1148
|
// parallel max
|
1129
1149
|
float4 lmax4 = -INFINITY;
|
1130
1150
|
|
1131
|
-
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
1132
|
-
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
|
1151
|
+
for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
|
1152
|
+
lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
|
1133
1153
|
}
|
1134
1154
|
|
1135
1155
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
@@ -1154,8 +1174,8 @@ kernel void kernel_soft_max_4(
|
|
1154
1174
|
|
1155
1175
|
// parallel sum
|
1156
1176
|
float4 lsum4 = 0.0f;
|
1157
|
-
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
1158
|
-
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
|
1177
|
+
for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
|
1178
|
+
const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
|
1159
1179
|
lsum4 += exp_psrc4;
|
1160
1180
|
pdst4[i00] = exp_psrc4;
|
1161
1181
|
}
|
@@ -1163,7 +1183,7 @@ kernel void kernel_soft_max_4(
|
|
1163
1183
|
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
1164
1184
|
|
1165
1185
|
// This barrier fixes a failing test
|
1166
|
-
// ref: https://github.com/
|
1186
|
+
// ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
|
1167
1187
|
threadgroup_barrier(mem_flags::mem_none);
|
1168
1188
|
|
1169
1189
|
float sum = simd_sum(lsum);
|
@@ -1187,7 +1207,7 @@ kernel void kernel_soft_max_4(
|
|
1187
1207
|
|
1188
1208
|
const float inv_sum = 1.0f/sum;
|
1189
1209
|
|
1190
|
-
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
1210
|
+
for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
|
1191
1211
|
pdst4[i00] *= inv_sum;
|
1192
1212
|
}
|
1193
1213
|
}
|
@@ -1203,27 +1223,23 @@ template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kerne
|
|
1203
1223
|
kernel void kernel_diag_mask_inf(
|
1204
1224
|
device const float * src0,
|
1205
1225
|
device float * dst,
|
1206
|
-
constant
|
1207
|
-
constant int64_t & ne01,
|
1208
|
-
constant int & n_past,
|
1226
|
+
constant ggml_metal_kargs_diag_mask_inf & args,
|
1209
1227
|
uint3 tpig[[thread_position_in_grid]]) {
|
1210
1228
|
const int64_t i02 = tpig[2];
|
1211
1229
|
const int64_t i01 = tpig[1];
|
1212
1230
|
const int64_t i00 = tpig[0];
|
1213
1231
|
|
1214
|
-
if (i00 > n_past + i01) {
|
1215
|
-
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
|
1232
|
+
if (i00 > args.n_past + i01) {
|
1233
|
+
dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = -INFINITY;
|
1216
1234
|
} else {
|
1217
|
-
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
|
1235
|
+
dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = src0[i02*args.ne01*args.ne00 + i01*args.ne00 + i00];
|
1218
1236
|
}
|
1219
1237
|
}
|
1220
1238
|
|
1221
1239
|
kernel void kernel_diag_mask_inf_8(
|
1222
1240
|
device const float4 * src0,
|
1223
1241
|
device float4 * dst,
|
1224
|
-
constant
|
1225
|
-
constant int64_t & ne01,
|
1226
|
-
constant int & n_past,
|
1242
|
+
constant ggml_metal_kargs_diag_mask_inf & args,
|
1227
1243
|
uint3 tpig[[thread_position_in_grid]]) {
|
1228
1244
|
|
1229
1245
|
const int64_t i = 2*tpig[0];
|
@@ -1231,42 +1247,26 @@ kernel void kernel_diag_mask_inf_8(
|
|
1231
1247
|
dst[i+0] = src0[i+0];
|
1232
1248
|
dst[i+1] = src0[i+1];
|
1233
1249
|
int64_t i4 = 4*i;
|
1234
|
-
const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
|
1235
|
-
const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
|
1250
|
+
const int64_t i02 = i4/(args.ne00*args.ne01); i4 -= i02*args.ne00*args.ne01;
|
1251
|
+
const int64_t i01 = i4/(args.ne00); i4 -= i01*args.ne00;
|
1236
1252
|
const int64_t i00 = i4;
|
1237
1253
|
for (int k = 3; k >= 0; --k) {
|
1238
|
-
if (i00 + 4 + k <= n_past + i01) {
|
1254
|
+
if (i00 + 4 + k <= args.n_past + i01) {
|
1239
1255
|
break;
|
1240
1256
|
}
|
1241
1257
|
dst[i+1][k] = -INFINITY;
|
1242
|
-
if (i00 + k > n_past + i01) {
|
1258
|
+
if (i00 + k > args.n_past + i01) {
|
1243
1259
|
dst[i][k] = -INFINITY;
|
1244
1260
|
}
|
1245
1261
|
}
|
1246
1262
|
}
|
1247
1263
|
|
1248
1264
|
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
|
1249
|
-
// TODO: optimize
|
1250
1265
|
kernel void kernel_ssm_conv_f32(
|
1251
1266
|
device const void * src0,
|
1252
1267
|
device const void * src1,
|
1253
1268
|
device float * dst,
|
1254
|
-
constant
|
1255
|
-
constant int64_t & ne01,
|
1256
|
-
constant int64_t & ne02,
|
1257
|
-
constant uint64_t & nb00,
|
1258
|
-
constant uint64_t & nb01,
|
1259
|
-
constant uint64_t & nb02,
|
1260
|
-
constant int64_t & ne10,
|
1261
|
-
constant int64_t & ne11,
|
1262
|
-
constant uint64_t & nb10,
|
1263
|
-
constant uint64_t & nb11,
|
1264
|
-
constant int64_t & ne0,
|
1265
|
-
constant int64_t & ne1,
|
1266
|
-
constant int64_t & ne2,
|
1267
|
-
constant uint64_t & nb0,
|
1268
|
-
constant uint64_t & nb1,
|
1269
|
-
constant uint64_t & nb2,
|
1269
|
+
constant ggml_metal_kargs_ssm_conv & args,
|
1270
1270
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1271
1271
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
1272
1272
|
uint3 ntg[[threads_per_threadgroup]]) {
|
@@ -1274,15 +1274,15 @@ kernel void kernel_ssm_conv_f32(
|
|
1274
1274
|
const int64_t i2 = tgpig.y;
|
1275
1275
|
const int64_t i3 = tgpig.z;
|
1276
1276
|
|
1277
|
-
const int64_t nc = ne10;
|
1278
|
-
//const int64_t ncs = ne00;
|
1279
|
-
//const int64_t nr = ne01;
|
1280
|
-
//const int64_t n_t = ne1;
|
1281
|
-
//const int64_t n_s = ne2;
|
1277
|
+
const int64_t nc = args.ne10;
|
1278
|
+
//const int64_t ncs = args.ne00;
|
1279
|
+
//const int64_t nr = args.ne01;
|
1280
|
+
//const int64_t n_t = args.ne1;
|
1281
|
+
//const int64_t n_s = args.ne2;
|
1282
1282
|
|
1283
|
-
device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
|
1284
|
-
device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
|
1285
|
-
device float * x = (device float *) ((device char *) dst + ir*nb0 + i2*nb1 + i3*nb2);
|
1283
|
+
device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
|
1284
|
+
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
|
1285
|
+
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
|
1286
1286
|
|
1287
1287
|
float sumf = 0.0f;
|
1288
1288
|
|
@@ -1294,7 +1294,6 @@ kernel void kernel_ssm_conv_f32(
|
|
1294
1294
|
}
|
1295
1295
|
|
1296
1296
|
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
|
1297
|
-
// TODO: optimize
|
1298
1297
|
kernel void kernel_ssm_scan_f32(
|
1299
1298
|
device const void * src0,
|
1300
1299
|
device const void * src1,
|
@@ -1303,48 +1302,27 @@ kernel void kernel_ssm_scan_f32(
|
|
1303
1302
|
device const void * src4,
|
1304
1303
|
device const void * src5,
|
1305
1304
|
device float * dst,
|
1306
|
-
constant
|
1307
|
-
constant int64_t & d_inner,
|
1308
|
-
constant int64_t & n_seq_tokens,
|
1309
|
-
constant int64_t & n_seqs,
|
1310
|
-
constant uint64_t & nb00,
|
1311
|
-
constant uint64_t & nb01,
|
1312
|
-
constant uint64_t & nb02,
|
1313
|
-
constant uint64_t & nb10,
|
1314
|
-
constant uint64_t & nb11,
|
1315
|
-
constant uint64_t & nb12,
|
1316
|
-
constant uint64_t & nb13,
|
1317
|
-
constant uint64_t & nb20,
|
1318
|
-
constant uint64_t & nb21,
|
1319
|
-
constant uint64_t & nb22,
|
1320
|
-
constant uint64_t & nb30,
|
1321
|
-
constant uint64_t & nb31,
|
1322
|
-
constant uint64_t & nb40,
|
1323
|
-
constant uint64_t & nb41,
|
1324
|
-
constant uint64_t & nb42,
|
1325
|
-
constant uint64_t & nb50,
|
1326
|
-
constant uint64_t & nb51,
|
1327
|
-
constant uint64_t & nb52,
|
1305
|
+
constant ggml_metal_kargs_ssm_scan & args,
|
1328
1306
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1329
1307
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
1330
1308
|
uint3 ntg[[threads_per_threadgroup]]) {
|
1331
1309
|
const int64_t ir = tgpig.x;
|
1332
1310
|
const int64_t i3 = tgpig.y;
|
1333
1311
|
|
1334
|
-
const int64_t nc = d_state;
|
1335
|
-
|
1336
|
-
const int64_t n_t = n_seq_tokens;
|
1337
|
-
|
1312
|
+
const int64_t nc = args.d_state;
|
1313
|
+
// const int64_t nr = args.d_inner;
|
1314
|
+
const int64_t n_t = args.n_seq_tokens;
|
1315
|
+
// const int64_t n_s = args.n_seqs;
|
1338
1316
|
|
1339
1317
|
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
1340
|
-
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
|
1341
|
-
device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
|
1342
|
-
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
|
1343
|
-
device const float * A = (device const float *) ((device const char *) src3 + ir*nb31);
|
1344
|
-
device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
|
1345
|
-
device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
|
1346
|
-
device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
|
1347
|
-
device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13);
|
1318
|
+
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02);
|
1319
|
+
device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12);
|
1320
|
+
device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22);
|
1321
|
+
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
|
1322
|
+
device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42);
|
1323
|
+
device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52);
|
1324
|
+
device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides
|
1325
|
+
device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13);
|
1348
1326
|
|
1349
1327
|
if (i2 > 0) {
|
1350
1328
|
s0 = s;
|
@@ -1366,6 +1344,184 @@ kernel void kernel_ssm_scan_f32(
|
|
1366
1344
|
}
|
1367
1345
|
}
|
1368
1346
|
|
1347
|
+
kernel void kernel_rwkv_wkv6_f32(
|
1348
|
+
device const float * k,
|
1349
|
+
device const float * v,
|
1350
|
+
device const float * r,
|
1351
|
+
device const float * tf,
|
1352
|
+
device const float * td,
|
1353
|
+
device const float * state_in,
|
1354
|
+
device float * dst,
|
1355
|
+
constant uint & B,
|
1356
|
+
constant uint & T,
|
1357
|
+
constant uint & C,
|
1358
|
+
constant uint & H,
|
1359
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1360
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
1361
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
1362
|
+
|
1363
|
+
const uint head_size = 64; // TODO: support head_size = 128
|
1364
|
+
const uint batch_id = tgpig.x / H;
|
1365
|
+
const uint head_id = tgpig.x % H;
|
1366
|
+
const uint tid = tpitg.x;
|
1367
|
+
|
1368
|
+
if (batch_id >= B || head_id >= H) {
|
1369
|
+
return;
|
1370
|
+
}
|
1371
|
+
|
1372
|
+
const uint state_size = C * head_size;
|
1373
|
+
const uint n_seq_tokens = T / B;
|
1374
|
+
|
1375
|
+
threadgroup float _k[head_size];
|
1376
|
+
threadgroup float _r[head_size];
|
1377
|
+
threadgroup float _tf[head_size];
|
1378
|
+
threadgroup float _td[head_size];
|
1379
|
+
|
1380
|
+
float state[head_size];
|
1381
|
+
|
1382
|
+
for (uint i = 0; i < head_size; i++) {
|
1383
|
+
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
|
1384
|
+
+ i * head_size + tid];
|
1385
|
+
}
|
1386
|
+
|
1387
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1388
|
+
_tf[tid] = tf[head_id * head_size + tid];
|
1389
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1390
|
+
|
1391
|
+
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
|
1392
|
+
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
|
1393
|
+
|
1394
|
+
for (uint t = start_t; t < end_t; t += C) {
|
1395
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1396
|
+
_k[tid] = k[t];
|
1397
|
+
_r[tid] = r[t];
|
1398
|
+
_td[tid] = td[t];
|
1399
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1400
|
+
|
1401
|
+
const float v_val = v[t];
|
1402
|
+
float y = 0.0;
|
1403
|
+
|
1404
|
+
for (uint j = 0; j < head_size; j += 4) {
|
1405
|
+
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
1406
|
+
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
1407
|
+
float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
1408
|
+
float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
|
1409
|
+
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
1410
|
+
|
1411
|
+
float4 kv = k_vec * v_val;
|
1412
|
+
|
1413
|
+
float4 temp = tf_vec * kv + s_vec;
|
1414
|
+
y += dot(r_vec, temp);
|
1415
|
+
|
1416
|
+
s_vec = s_vec * td_vec + kv;
|
1417
|
+
state[j] = s_vec[0];
|
1418
|
+
state[j+1] = s_vec[1];
|
1419
|
+
state[j+2] = s_vec[2];
|
1420
|
+
state[j+3] = s_vec[3];
|
1421
|
+
}
|
1422
|
+
|
1423
|
+
dst[t] = y;
|
1424
|
+
}
|
1425
|
+
|
1426
|
+
for (uint i = 0; i < head_size; i++) {
|
1427
|
+
dst[T * C + batch_id * state_size + head_id * head_size * head_size
|
1428
|
+
+ i * head_size + tid] = state[i];
|
1429
|
+
}
|
1430
|
+
}
|
1431
|
+
|
1432
|
+
kernel void kernel_rwkv_wkv7_f32(
|
1433
|
+
device const float * r,
|
1434
|
+
device const float * w,
|
1435
|
+
device const float * k,
|
1436
|
+
device const float * v,
|
1437
|
+
device const float * a,
|
1438
|
+
device const float * b,
|
1439
|
+
device const float * state_in,
|
1440
|
+
device float * dst,
|
1441
|
+
constant uint & B,
|
1442
|
+
constant uint & T,
|
1443
|
+
constant uint & C,
|
1444
|
+
constant uint & H,
|
1445
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1446
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
1447
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
1448
|
+
|
1449
|
+
const uint head_size = 64; // TODO: support head_size = 128
|
1450
|
+
const uint batch_id = tgpig.x / H;
|
1451
|
+
const uint head_id = tgpig.x % H;
|
1452
|
+
const uint tid = tpitg.x;
|
1453
|
+
|
1454
|
+
if (batch_id >= B || head_id >= H) {
|
1455
|
+
return;
|
1456
|
+
}
|
1457
|
+
|
1458
|
+
const uint state_size = C * head_size;
|
1459
|
+
const uint n_seq_tokens = T / B;
|
1460
|
+
|
1461
|
+
threadgroup float _r[head_size];
|
1462
|
+
threadgroup float _w[head_size];
|
1463
|
+
threadgroup float _k[head_size];
|
1464
|
+
threadgroup float _a[head_size];
|
1465
|
+
threadgroup float _b[head_size];
|
1466
|
+
|
1467
|
+
float state[head_size];
|
1468
|
+
|
1469
|
+
for (uint i = 0; i < head_size; i++) {
|
1470
|
+
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
|
1471
|
+
+ tid * head_size + i];
|
1472
|
+
}
|
1473
|
+
|
1474
|
+
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
|
1475
|
+
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
|
1476
|
+
|
1477
|
+
for (uint t = start_t; t < end_t; t += C) {
|
1478
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1479
|
+
_r[tid] = r[t];
|
1480
|
+
_w[tid] = w[t];
|
1481
|
+
_k[tid] = k[t];
|
1482
|
+
_a[tid] = a[t];
|
1483
|
+
_b[tid] = b[t];
|
1484
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1485
|
+
|
1486
|
+
const float v_val = v[t];
|
1487
|
+
float y = 0.0, sa = 0.0;
|
1488
|
+
|
1489
|
+
float4 sa_vec(0.0);
|
1490
|
+
|
1491
|
+
for (uint j = 0; j < head_size; j += 4) {
|
1492
|
+
float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
|
1493
|
+
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
1494
|
+
sa_vec += a_vec * s_vec;
|
1495
|
+
}
|
1496
|
+
sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
|
1497
|
+
|
1498
|
+
for (uint j = 0; j < head_size; j += 4) {
|
1499
|
+
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
1500
|
+
float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
|
1501
|
+
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
1502
|
+
float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
|
1503
|
+
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
1504
|
+
|
1505
|
+
float4 kv = k_vec * v_val;
|
1506
|
+
|
1507
|
+
s_vec = s_vec * w_vec + kv + sa * b_vec;
|
1508
|
+
y += dot(s_vec, r_vec);
|
1509
|
+
|
1510
|
+
state[j] = s_vec[0];
|
1511
|
+
state[j+1] = s_vec[1];
|
1512
|
+
state[j+2] = s_vec[2];
|
1513
|
+
state[j+3] = s_vec[3];
|
1514
|
+
}
|
1515
|
+
|
1516
|
+
dst[t] = y;
|
1517
|
+
}
|
1518
|
+
|
1519
|
+
for (uint i = 0; i < head_size; i++) {
|
1520
|
+
dst[T * C + batch_id * state_size + head_id * head_size * head_size
|
1521
|
+
+ tid * head_size + i] = state[i];
|
1522
|
+
}
|
1523
|
+
}
|
1524
|
+
|
1369
1525
|
kernel void kernel_argmax(
|
1370
1526
|
device const void * x,
|
1371
1527
|
device int32_t * dst,
|
@@ -1534,25 +1690,61 @@ kernel void kernel_rms_norm(
|
|
1534
1690
|
}
|
1535
1691
|
}
|
1536
1692
|
|
1693
|
+
kernel void kernel_l2_norm(
|
1694
|
+
constant ggml_metal_kargs_l2_norm & args,
|
1695
|
+
device const char * src0,
|
1696
|
+
device char * dst,
|
1697
|
+
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
1698
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
1699
|
+
ushort tpitg[[thread_position_in_threadgroup]],
|
1700
|
+
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
1701
|
+
ushort tiisg[[thread_index_in_simdgroup]],
|
1702
|
+
ushort ntg[[threads_per_threadgroup]]) {
|
1703
|
+
if (sgitg == 0) {
|
1704
|
+
shmem_f32[tiisg] = 0.0f;
|
1705
|
+
}
|
1706
|
+
|
1707
|
+
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
|
1708
|
+
|
1709
|
+
float sumf = 0.0f;
|
1710
|
+
|
1711
|
+
// parallel sum
|
1712
|
+
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
1713
|
+
sumf += dot(x[i00], x[i00]);
|
1714
|
+
}
|
1715
|
+
sumf = simd_sum(sumf);
|
1716
|
+
|
1717
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1718
|
+
|
1719
|
+
if (tiisg == 0) {
|
1720
|
+
shmem_f32[sgitg] = sumf;
|
1721
|
+
}
|
1722
|
+
|
1723
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1724
|
+
|
1725
|
+
sumf = shmem_f32[tiisg];
|
1726
|
+
sumf = simd_sum(sumf);
|
1727
|
+
|
1728
|
+
const float scale = 1.0f/sqrt(max(sumf, args.eps));
|
1729
|
+
|
1730
|
+
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
|
1731
|
+
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
1732
|
+
y[i00] = x[i00] * scale;
|
1733
|
+
}
|
1734
|
+
}
|
1735
|
+
|
1537
1736
|
kernel void kernel_group_norm(
|
1538
1737
|
device const float * src0,
|
1539
1738
|
device float * dst,
|
1540
|
-
constant
|
1541
|
-
constant int64_t & ne01,
|
1542
|
-
constant int64_t & ne02,
|
1543
|
-
constant uint64_t & nb00,
|
1544
|
-
constant uint64_t & nb01,
|
1545
|
-
constant uint64_t & nb02,
|
1546
|
-
constant int32_t & n_groups,
|
1547
|
-
constant float & eps,
|
1739
|
+
constant ggml_metal_kargs_group_norm & args,
|
1548
1740
|
threadgroup float * buf [[threadgroup(0)]],
|
1549
1741
|
uint tgpig[[threadgroup_position_in_grid]],
|
1550
1742
|
uint tpitg[[thread_position_in_threadgroup]],
|
1551
1743
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
1552
1744
|
uint tiisg[[thread_index_in_simdgroup]],
|
1553
1745
|
uint ntg[[threads_per_threadgroup]]) {
|
1554
|
-
const int64_t ne = ne00*ne01*ne02;
|
1555
|
-
const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
|
1746
|
+
const int64_t ne = args.ne00*args.ne01*args.ne02;
|
1747
|
+
const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.n_groups - 1) / args.n_groups);
|
1556
1748
|
|
1557
1749
|
int start = tgpig * gs;
|
1558
1750
|
int end = start + gs;
|
@@ -1616,7 +1808,7 @@ kernel void kernel_group_norm(
|
|
1616
1808
|
}
|
1617
1809
|
|
1618
1810
|
const float variance = tmp / gs;
|
1619
|
-
const float scale = 1.0f/sqrt(variance + eps);
|
1811
|
+
const float scale = 1.0f/sqrt(variance + args.eps);
|
1620
1812
|
for (int j = start; j < end; j += ntg) {
|
1621
1813
|
dst[j] *= scale;
|
1622
1814
|
}
|
@@ -1710,14 +1902,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
|
1710
1902
|
return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
|
1711
1903
|
}
|
1712
1904
|
|
1713
|
-
|
1714
|
-
#define N_DST 4 // each SIMD group works on 4 rows
|
1715
|
-
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
1716
|
-
//Note: This is a template, but strictly speaking it only applies to
|
1717
|
-
// quantizations where the block size is 32. It also does not
|
1718
|
-
// guard against the number of rows not being divisible by
|
1719
|
-
// N_DST, so this is another explicit assumption of the implementation.
|
1720
|
-
template<typename block_q_type, int nr, int nsg, int nw, typename args_t>
|
1905
|
+
template<typename block_q_type, int nr0, int nsg, int nw, typename args_t>
|
1721
1906
|
void mul_vec_q_n_f32_impl(
|
1722
1907
|
args_t args,
|
1723
1908
|
device const char * src0,
|
@@ -1733,7 +1918,7 @@ void mul_vec_q_n_f32_impl(
|
|
1733
1918
|
const int r1 = tgpig.y;
|
1734
1919
|
const int im = tgpig.z;
|
1735
1920
|
|
1736
|
-
const int first_row = (r0 * nsg + sgitg) *
|
1921
|
+
const int first_row = (r0 * nsg + sgitg) * nr0;
|
1737
1922
|
|
1738
1923
|
const uint i12 = im%args.ne12;
|
1739
1924
|
const uint i13 = im/args.ne12;
|
@@ -1745,15 +1930,15 @@ void mul_vec_q_n_f32_impl(
|
|
1745
1930
|
device const float * y = (device const float *) (src1 + offset1);
|
1746
1931
|
|
1747
1932
|
// pointers to src0 rows
|
1748
|
-
device const block_q_type * ax[
|
1749
|
-
for (int row = 0; row <
|
1933
|
+
device const block_q_type * ax[nr0];
|
1934
|
+
for (int row = 0; row < nr0; ++row) {
|
1750
1935
|
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
1751
1936
|
|
1752
1937
|
ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
|
1753
1938
|
}
|
1754
1939
|
|
1755
1940
|
float yl[16]; // src1 vector cache
|
1756
|
-
float sumf[
|
1941
|
+
float sumf[nr0] = {0.f};
|
1757
1942
|
|
1758
1943
|
const short ix = (tiisg/2);
|
1759
1944
|
const short il = (tiisg%2)*8;
|
@@ -1765,7 +1950,7 @@ void mul_vec_q_n_f32_impl(
|
|
1765
1950
|
float sumy[2] = { 0.f, 0.f };
|
1766
1951
|
|
1767
1952
|
#pragma unroll
|
1768
|
-
for (
|
1953
|
+
for (short i = 0; i < 8; i += 2) {
|
1769
1954
|
sumy[0] += yb[i + 0] + yb[i + 1];
|
1770
1955
|
yl[i + 0] = yb[i + 0];
|
1771
1956
|
yl[i + 1] = yb[i + 1]/256.f;
|
@@ -1776,7 +1961,7 @@ void mul_vec_q_n_f32_impl(
|
|
1776
1961
|
}
|
1777
1962
|
|
1778
1963
|
#pragma unroll
|
1779
|
-
for (
|
1964
|
+
for (short row = 0; row < nr0; row++) {
|
1780
1965
|
sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
|
1781
1966
|
}
|
1782
1967
|
|
@@ -1785,7 +1970,7 @@ void mul_vec_q_n_f32_impl(
|
|
1785
1970
|
|
1786
1971
|
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
1787
1972
|
|
1788
|
-
for (int row = 0; row <
|
1973
|
+
for (int row = 0; row < nr0; ++row) {
|
1789
1974
|
const float tot = simd_sum(sumf[row]);
|
1790
1975
|
|
1791
1976
|
if (tiisg == 0 && first_row + row < args.ne01) {
|
@@ -1802,7 +1987,7 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|
1802
1987
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1803
1988
|
ushort tiisg[[thread_index_in_simdgroup]],
|
1804
1989
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
1805
|
-
mul_vec_q_n_f32_impl<block_q4_0,
|
1990
|
+
mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
1806
1991
|
}
|
1807
1992
|
|
1808
1993
|
kernel void kernel_mul_mv_q4_1_f32(
|
@@ -1813,7 +1998,7 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
1813
1998
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1814
1999
|
ushort tiisg[[thread_index_in_simdgroup]],
|
1815
2000
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
1816
|
-
mul_vec_q_n_f32_impl<block_q4_1,
|
2001
|
+
mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
1817
2002
|
}
|
1818
2003
|
|
1819
2004
|
kernel void kernel_mul_mv_q5_0_f32(
|
@@ -1824,7 +2009,7 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|
1824
2009
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1825
2010
|
ushort tiisg[[thread_index_in_simdgroup]],
|
1826
2011
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
1827
|
-
mul_vec_q_n_f32_impl<block_q5_0,
|
2012
|
+
mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
1828
2013
|
}
|
1829
2014
|
|
1830
2015
|
kernel void kernel_mul_mv_q5_1_f32(
|
@@ -1835,12 +2020,12 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|
1835
2020
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1836
2021
|
ushort tiisg[[thread_index_in_simdgroup]],
|
1837
2022
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
1838
|
-
mul_vec_q_n_f32_impl<block_q5_1,
|
2023
|
+
mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
1839
2024
|
}
|
1840
2025
|
|
1841
2026
|
#define NB_Q8_0 8
|
1842
2027
|
|
1843
|
-
template<typename args_t>
|
2028
|
+
template<int nr0, int nsg, int nw, typename args_t>
|
1844
2029
|
void kernel_mul_mv_q8_0_f32_impl(
|
1845
2030
|
args_t args,
|
1846
2031
|
device const char * src0,
|
@@ -1850,16 +2035,13 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|
1850
2035
|
uint3 tgpig,
|
1851
2036
|
ushort tiisg,
|
1852
2037
|
ushort sgitg) {
|
1853
|
-
const int nr = N_DST;
|
1854
|
-
const int nsg = N_SIMDGROUP;
|
1855
|
-
const int nw = N_SIMDWIDTH;
|
1856
|
-
|
1857
2038
|
const int nb = args.ne00/QK8_0;
|
2039
|
+
|
1858
2040
|
const int r0 = tgpig.x;
|
1859
2041
|
const int r1 = tgpig.y;
|
1860
2042
|
const int im = tgpig.z;
|
1861
2043
|
|
1862
|
-
const int first_row = (r0*nsg + sgitg)*
|
2044
|
+
const int first_row = (r0 * nsg + sgitg) * nr0;
|
1863
2045
|
|
1864
2046
|
const uint i12 = im%args.ne12;
|
1865
2047
|
const uint i13 = im/args.ne12;
|
@@ -1871,15 +2053,15 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|
1871
2053
|
device const float * y = (device const float *) (src1 + offset1);
|
1872
2054
|
|
1873
2055
|
// pointers to src0 rows
|
1874
|
-
device const block_q8_0 * ax[
|
1875
|
-
for (int row = 0; row <
|
2056
|
+
device const block_q8_0 * ax[nr0];
|
2057
|
+
for (int row = 0; row < nr0; ++row) {
|
1876
2058
|
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
1877
2059
|
|
1878
2060
|
ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
|
1879
2061
|
}
|
1880
2062
|
|
1881
2063
|
float yl[NB_Q8_0];
|
1882
|
-
float sumf[
|
2064
|
+
float sumf[nr0] = { 0.f };
|
1883
2065
|
|
1884
2066
|
const short ix = tiisg/4;
|
1885
2067
|
const short il = tiisg%4;
|
@@ -1892,7 +2074,7 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|
1892
2074
|
yl[i] = yb[i];
|
1893
2075
|
}
|
1894
2076
|
|
1895
|
-
for (
|
2077
|
+
for (short row = 0; row < nr0; row++) {
|
1896
2078
|
device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
|
1897
2079
|
float sumq = 0.f;
|
1898
2080
|
for (short iq = 0; iq < NB_Q8_0; ++iq) {
|
@@ -1906,7 +2088,7 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|
1906
2088
|
|
1907
2089
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
1908
2090
|
|
1909
|
-
for (int row = 0; row <
|
2091
|
+
for (int row = 0; row < nr0; ++row) {
|
1910
2092
|
const float tot = simd_sum(sumf[row]);
|
1911
2093
|
|
1912
2094
|
if (tiisg == 0 && first_row + row < args.ne01) {
|
@@ -1924,7 +2106,7 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
1924
2106
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1925
2107
|
ushort tiisg[[thread_index_in_simdgroup]],
|
1926
2108
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
1927
|
-
kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
2109
|
+
kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
1928
2110
|
}
|
1929
2111
|
|
1930
2112
|
// mat-vec kernel processing in chunks of float4
|
@@ -2261,9 +2443,9 @@ void kernel_mul_mv_impl(
|
|
2261
2443
|
sumf += (T0) x[i] * (T1) y[i];
|
2262
2444
|
}
|
2263
2445
|
|
2264
|
-
float
|
2446
|
+
float sum_all = simd_sum(sumf);
|
2265
2447
|
if (tiisg == 0) {
|
2266
|
-
dst_f32[(uint64_t)r1*args.ne0 + r0] =
|
2448
|
+
dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
|
2267
2449
|
}
|
2268
2450
|
}
|
2269
2451
|
} else {
|
@@ -2284,10 +2466,10 @@ void kernel_mul_mv_impl(
|
|
2284
2466
|
sumf += dot((float4) x4[i], (float4) y4[i]);
|
2285
2467
|
}
|
2286
2468
|
|
2287
|
-
float
|
2469
|
+
float sum_all = simd_sum(sumf);
|
2288
2470
|
if (tiisg == 0) {
|
2289
|
-
for (int i = 4*(args.ne00/4); i < args.ne00; ++i)
|
2290
|
-
dst_f32[(uint64_t)r1*args.ne0 + r0] =
|
2471
|
+
for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
|
2472
|
+
dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
|
2291
2473
|
}
|
2292
2474
|
}
|
2293
2475
|
}
|
@@ -2349,9 +2531,9 @@ kernel void kernel_mul_mv_1row(
|
|
2349
2531
|
for (int i = tiisg; i < args.ne00; i += 32) {
|
2350
2532
|
sumf += (float) x[i] * (float) y[i];
|
2351
2533
|
}
|
2352
|
-
float
|
2534
|
+
float sum_all = simd_sum(sumf);
|
2353
2535
|
if (tiisg == 0) {
|
2354
|
-
dst_f32[r0] =
|
2536
|
+
dst_f32[r0] = sum_all;
|
2355
2537
|
}
|
2356
2538
|
} else {
|
2357
2539
|
device const T4 * x4 = (device const T4 *) x;
|
@@ -2361,11 +2543,11 @@ kernel void kernel_mul_mv_1row(
|
|
2361
2543
|
sumf += dot((float4) x4[i], y4[i]);
|
2362
2544
|
}
|
2363
2545
|
|
2364
|
-
float
|
2546
|
+
float sum_all = simd_sum(sumf);
|
2365
2547
|
|
2366
2548
|
if (tiisg == 0) {
|
2367
|
-
for (int i = 4*(args.ne00/4); i < args.ne00; ++i)
|
2368
|
-
dst_f32[r0] =
|
2549
|
+
for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
|
2550
|
+
dst_f32[r0] = sum_all;
|
2369
2551
|
}
|
2370
2552
|
}
|
2371
2553
|
}
|
@@ -2410,9 +2592,9 @@ kernel void kernel_mul_mv_l4(
|
|
2410
2592
|
sumf += dot((float4) x4[i], y4[i]);
|
2411
2593
|
}
|
2412
2594
|
|
2413
|
-
float
|
2595
|
+
float sum_all = simd_sum(sumf);
|
2414
2596
|
if (tiisg == 0) {
|
2415
|
-
dst_f32[(uint64_t)r1*args.ne0 + r0] =
|
2597
|
+
dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
|
2416
2598
|
}
|
2417
2599
|
}
|
2418
2600
|
}
|
@@ -2568,8 +2750,148 @@ kernel void kernel_rope_neox(
|
|
2568
2750
|
}
|
2569
2751
|
}
|
2570
2752
|
|
2753
|
+
template<typename T>
|
2754
|
+
kernel void kernel_rope_multi(
|
2755
|
+
constant ggml_metal_kargs_rope & args,
|
2756
|
+
device const char * src0,
|
2757
|
+
device const char * src1,
|
2758
|
+
device const char * src2,
|
2759
|
+
device char * dst,
|
2760
|
+
ushort tiitg[[thread_index_in_threadgroup]],
|
2761
|
+
ushort3 tptg [[threads_per_threadgroup]],
|
2762
|
+
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
2763
|
+
const int i3 = tgpig[2];
|
2764
|
+
const int i2 = tgpig[1];
|
2765
|
+
const int i1 = tgpig[0];
|
2766
|
+
|
2767
|
+
float corr_dims[2];
|
2768
|
+
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
2769
|
+
|
2770
|
+
device const int32_t * pos = (device const int32_t *) src1;
|
2771
|
+
|
2772
|
+
const float inv_ndims = -1.f/args.n_dims;
|
2773
|
+
|
2774
|
+
float cos_theta;
|
2775
|
+
float sin_theta;
|
2776
|
+
|
2777
|
+
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
2778
|
+
if (i0 < args.n_dims) {
|
2779
|
+
const int ic = i0/2;
|
2780
|
+
|
2781
|
+
// mrope theta calculations
|
2782
|
+
// note: the rest is the same as kernel_rope_neox
|
2783
|
+
const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
|
2784
|
+
const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1
|
2785
|
+
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
|
2786
|
+
const int sector = ic % sect_dims;
|
2787
|
+
|
2788
|
+
float theta_base;
|
2789
|
+
if (sector < args.sect_0) {
|
2790
|
+
theta_base = (float) pos[i2];
|
2791
|
+
} else if (sector < sec_w01) {
|
2792
|
+
theta_base = (float) pos[i2 + args.ne02];
|
2793
|
+
} else if (sector < sec_w012) {
|
2794
|
+
theta_base = (float) pos[i2 + args.ne02 * 2];
|
2795
|
+
} else {
|
2796
|
+
theta_base = (float) pos[i2 + args.ne02 * 3];
|
2797
|
+
}
|
2798
|
+
// end of mrope
|
2799
|
+
|
2800
|
+
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
|
2801
|
+
|
2802
|
+
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
2803
|
+
|
2804
|
+
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
2805
|
+
|
2806
|
+
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
|
2807
|
+
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
|
2808
|
+
|
2809
|
+
const float x0 = src[0];
|
2810
|
+
const float x1 = src[args.n_dims/2];
|
2811
|
+
|
2812
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
2813
|
+
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
|
2814
|
+
} else {
|
2815
|
+
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
2816
|
+
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
2817
|
+
|
2818
|
+
dst_data[0] = src[0];
|
2819
|
+
dst_data[1] = src[1];
|
2820
|
+
}
|
2821
|
+
}
|
2822
|
+
}
|
2823
|
+
|
2824
|
+
template<typename T>
|
2825
|
+
kernel void kernel_rope_vision(
|
2826
|
+
constant ggml_metal_kargs_rope & args,
|
2827
|
+
device const char * src0,
|
2828
|
+
device const char * src1,
|
2829
|
+
device const char * src2,
|
2830
|
+
device char * dst,
|
2831
|
+
ushort tiitg[[thread_index_in_threadgroup]],
|
2832
|
+
ushort3 tptg [[threads_per_threadgroup]],
|
2833
|
+
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
2834
|
+
const int i3 = tgpig[2];
|
2835
|
+
const int i2 = tgpig[1];
|
2836
|
+
const int i1 = tgpig[0];
|
2837
|
+
|
2838
|
+
float corr_dims[2];
|
2839
|
+
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
2840
|
+
|
2841
|
+
device const int32_t * pos = (device const int32_t *) src1;
|
2842
|
+
|
2843
|
+
const float inv_ndims = -1.f/args.n_dims;
|
2844
|
+
|
2845
|
+
float cos_theta;
|
2846
|
+
float sin_theta;
|
2847
|
+
|
2848
|
+
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
2849
|
+
if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
|
2850
|
+
const int ic = i0/2;
|
2851
|
+
|
2852
|
+
// mrope theta calculations (only support 2 dimensions)
|
2853
|
+
const int sect_dims = args.sect_0 + args.sect_1;
|
2854
|
+
const int sector = ic % sect_dims;
|
2855
|
+
|
2856
|
+
float p;
|
2857
|
+
float theta_base;
|
2858
|
+
if (sector < args.sect_1) {
|
2859
|
+
p = (float) sector;
|
2860
|
+
theta_base = (float) pos[i2];
|
2861
|
+
} else {
|
2862
|
+
p = (float) sector - args.sect_0;
|
2863
|
+
theta_base = (float) pos[i2 + args.ne02];
|
2864
|
+
}
|
2865
|
+
|
2866
|
+
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
|
2867
|
+
// end of mrope
|
2868
|
+
|
2869
|
+
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
2870
|
+
|
2871
|
+
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
2872
|
+
|
2873
|
+
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
|
2874
|
+
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
|
2875
|
+
|
2876
|
+
const float x0 = src[0];
|
2877
|
+
const float x1 = src[args.n_dims]; // different from kernel_rope_multi
|
2878
|
+
|
2879
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
2880
|
+
dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
|
2881
|
+
} else {
|
2882
|
+
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
2883
|
+
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
2884
|
+
|
2885
|
+
dst_data[0] = src[0];
|
2886
|
+
dst_data[1] = src[1];
|
2887
|
+
}
|
2888
|
+
}
|
2889
|
+
}
|
2890
|
+
|
2571
2891
|
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
|
2572
2892
|
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
|
2893
|
+
typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
|
2894
|
+
typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
|
2573
2895
|
|
2574
2896
|
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
|
2575
2897
|
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
|
@@ -2577,20 +2899,16 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_
|
|
2577
2899
|
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
|
2578
2900
|
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
|
2579
2901
|
|
2902
|
+
template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
|
2903
|
+
template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
|
2904
|
+
|
2905
|
+
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
|
2906
|
+
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
|
2907
|
+
|
2580
2908
|
typedef void (im2col_t)(
|
2581
2909
|
device const float * x,
|
2582
2910
|
device char * dst,
|
2583
|
-
constant
|
2584
|
-
constant int32_t & ofs1,
|
2585
|
-
constant int32_t & IW,
|
2586
|
-
constant int32_t & IH,
|
2587
|
-
constant int32_t & CHW,
|
2588
|
-
constant int32_t & s0,
|
2589
|
-
constant int32_t & s1,
|
2590
|
-
constant int32_t & p0,
|
2591
|
-
constant int32_t & p1,
|
2592
|
-
constant int32_t & d0,
|
2593
|
-
constant int32_t & d1,
|
2911
|
+
constant ggml_metal_kargs_im2col & args,
|
2594
2912
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2595
2913
|
uint3 tgpg[[threadgroups_per_grid]],
|
2596
2914
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
@@ -2600,17 +2918,7 @@ template <typename T>
|
|
2600
2918
|
kernel void kernel_im2col(
|
2601
2919
|
device const float * x,
|
2602
2920
|
device char * dst,
|
2603
|
-
constant
|
2604
|
-
constant int32_t & ofs1,
|
2605
|
-
constant int32_t & IW,
|
2606
|
-
constant int32_t & IH,
|
2607
|
-
constant int32_t & CHW,
|
2608
|
-
constant int32_t & s0,
|
2609
|
-
constant int32_t & s1,
|
2610
|
-
constant int32_t & p0,
|
2611
|
-
constant int32_t & p1,
|
2612
|
-
constant int32_t & d0,
|
2613
|
-
constant int32_t & d1,
|
2921
|
+
constant ggml_metal_kargs_im2col & args,
|
2614
2922
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2615
2923
|
uint3 tgpg[[threadgroups_per_grid]],
|
2616
2924
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
@@ -2631,17 +2939,17 @@ kernel void kernel_im2col(
|
|
2631
2939
|
const int64_t ioh = tgpig[1];
|
2632
2940
|
const int64_t iow = tgpig[2];
|
2633
2941
|
|
2634
|
-
const int64_t iiw = iow*s0 + ikw*d0 - p0;
|
2635
|
-
const int64_t iih = ioh*s1 + ikh*d1 - p1;
|
2942
|
+
const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
|
2943
|
+
const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
|
2636
2944
|
|
2637
|
-
const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*CHW + (iic*(KH*KW) + ikh*KW + ikw);
|
2945
|
+
const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
|
2638
2946
|
|
2639
2947
|
device T * pdst = (device T *) (dst);
|
2640
2948
|
|
2641
|
-
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
2949
|
+
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
|
2642
2950
|
pdst[offset_dst] = 0.0f;
|
2643
2951
|
} else {
|
2644
|
-
const int64_t offset_src = in*ofs0 + iic*ofs1 + iih*IW + iiw;
|
2952
|
+
const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
|
2645
2953
|
pdst[offset_dst] = x[offset_src];
|
2646
2954
|
}
|
2647
2955
|
}
|
@@ -2652,20 +2960,7 @@ template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
|
|
2652
2960
|
typedef void (im2col_ext_t)(
|
2653
2961
|
device const float * x,
|
2654
2962
|
device char * dst,
|
2655
|
-
constant
|
2656
|
-
constant int32_t & ofs1,
|
2657
|
-
constant int32_t & IW,
|
2658
|
-
constant int32_t & IH,
|
2659
|
-
constant int32_t & CHW,
|
2660
|
-
constant int32_t & s0,
|
2661
|
-
constant int32_t & s1,
|
2662
|
-
constant int32_t & p0,
|
2663
|
-
constant int32_t & p1,
|
2664
|
-
constant int32_t & d0,
|
2665
|
-
constant int32_t & d1,
|
2666
|
-
constant int32_t & N,
|
2667
|
-
constant int32_t & KH,
|
2668
|
-
constant int32_t & KW,
|
2963
|
+
constant ggml_metal_kargs_im2col & args,
|
2669
2964
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2670
2965
|
uint3 tgpg[[threadgroups_per_grid]],
|
2671
2966
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
@@ -2675,53 +2970,40 @@ template <typename T>
|
|
2675
2970
|
kernel void kernel_im2col_ext(
|
2676
2971
|
device const float * x,
|
2677
2972
|
device char * dst,
|
2678
|
-
constant
|
2679
|
-
constant int32_t & ofs1,
|
2680
|
-
constant int32_t & IW,
|
2681
|
-
constant int32_t & IH,
|
2682
|
-
constant int32_t & CHW,
|
2683
|
-
constant int32_t & s0,
|
2684
|
-
constant int32_t & s1,
|
2685
|
-
constant int32_t & p0,
|
2686
|
-
constant int32_t & p1,
|
2687
|
-
constant int32_t & d0,
|
2688
|
-
constant int32_t & d1,
|
2689
|
-
constant int32_t & N,
|
2690
|
-
constant int32_t & KH,
|
2691
|
-
constant int32_t & KW,
|
2973
|
+
constant ggml_metal_kargs_im2col & args,
|
2692
2974
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2693
2975
|
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
|
2694
2976
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
2695
2977
|
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
|
2696
|
-
const int64_t KHW =
|
2978
|
+
const int64_t KHW = (int64_t)args.KHW;
|
2697
2979
|
|
2698
|
-
const int64_t d = tgpig[0] / CHW;
|
2699
|
-
const int64_t chw = tgpig[0] % CHW;
|
2980
|
+
const int64_t d = tgpig[0] / args.CHW;
|
2981
|
+
const int64_t chw = tgpig[0] % args.CHW;
|
2700
2982
|
const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
|
2701
2983
|
const int64_t HW = tgpig[0] % KHW;
|
2702
2984
|
|
2703
2985
|
const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
|
2704
|
-
if (tpitg_0 >= N) {
|
2986
|
+
if (tpitg_0 >= args.N) {
|
2705
2987
|
return;
|
2706
2988
|
}
|
2707
2989
|
|
2708
|
-
const int64_t tpitg_1 = HW / KW;
|
2709
|
-
const int64_t tpitg_2 = HW % KW;
|
2990
|
+
const int64_t tpitg_1 = HW / args.KW;
|
2991
|
+
const int64_t tpitg_2 = HW % args.KW;
|
2710
2992
|
|
2711
|
-
const int64_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
|
2712
|
-
const int64_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
|
2993
|
+
const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
|
2994
|
+
const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
|
2713
2995
|
|
2714
2996
|
const int64_t offset_dst =
|
2715
|
-
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
|
2716
|
-
(tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
|
2997
|
+
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
|
2998
|
+
(tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
|
2717
2999
|
|
2718
3000
|
device T * pdst = (device T *) (dst);
|
2719
3001
|
|
2720
|
-
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
3002
|
+
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
|
2721
3003
|
pdst[offset_dst] = 0.0f;
|
2722
3004
|
} else {
|
2723
|
-
const int64_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
|
2724
|
-
pdst[offset_dst] = x[offset_src + iih * IW + iiw];
|
3005
|
+
const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
|
3006
|
+
pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
|
2725
3007
|
}
|
2726
3008
|
}
|
2727
3009
|
|
@@ -2732,12 +3014,7 @@ typedef void (conv_transpose_1d_t)(
|
|
2732
3014
|
device const float * src0,
|
2733
3015
|
device const float * src1,
|
2734
3016
|
device char * dst,
|
2735
|
-
constant
|
2736
|
-
constant int32_t & IL,
|
2737
|
-
constant int32_t & K,
|
2738
|
-
constant int32_t & s0,
|
2739
|
-
constant uint64_t & nb0,
|
2740
|
-
constant uint64_t & nb1,
|
3017
|
+
constant ggml_metal_kargs_conv_transpose_1d & args,
|
2741
3018
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2742
3019
|
uint3 tgpg[[threadgroups_per_grid]]);
|
2743
3020
|
|
@@ -2746,29 +3023,24 @@ kernel void kernel_conv_transpose_1d(
|
|
2746
3023
|
device const T * src0,
|
2747
3024
|
device const float * src1,
|
2748
3025
|
device char * dst,
|
2749
|
-
constant
|
2750
|
-
constant int32_t & IL,
|
2751
|
-
constant int32_t & K,
|
2752
|
-
constant int32_t & s0,
|
2753
|
-
constant uint64_t & nb0,
|
2754
|
-
constant uint64_t & nb1,
|
3026
|
+
constant ggml_metal_kargs_conv_transpose_1d & args,
|
2755
3027
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2756
3028
|
uint3 tgpg[[threadgroups_per_grid]]) {
|
2757
3029
|
|
2758
3030
|
float v = 0.0f;
|
2759
3031
|
|
2760
|
-
for (int64_t c = 0; c < IC; c++) {
|
2761
|
-
const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
|
2762
|
-
const int32_t input_offset = c * IL;
|
3032
|
+
for (int64_t c = 0; c < args.IC; c++) {
|
3033
|
+
const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1];
|
3034
|
+
const int32_t input_offset = c * args.IL;
|
2763
3035
|
|
2764
|
-
for (int64_t i = 0; i < IL; i++) {
|
2765
|
-
if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) {
|
2766
|
-
v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i];
|
3036
|
+
for (int64_t i = 0; i < args.IL; i++) {
|
3037
|
+
if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) {
|
3038
|
+
v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i];
|
2767
3039
|
}
|
2768
3040
|
}
|
2769
3041
|
}
|
2770
3042
|
|
2771
|
-
device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1);
|
3043
|
+
device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1);
|
2772
3044
|
|
2773
3045
|
dst_ptr[0] = v;
|
2774
3046
|
}
|
@@ -2778,12 +3050,7 @@ kernel void kernel_conv_transpose_1d<float>(
|
|
2778
3050
|
device const float * src0,
|
2779
3051
|
device const float * src1,
|
2780
3052
|
device char * dst,
|
2781
|
-
constant
|
2782
|
-
constant int32_t & IL,
|
2783
|
-
constant int32_t & K,
|
2784
|
-
constant int32_t & s0,
|
2785
|
-
constant uint64_t & nb0,
|
2786
|
-
constant uint64_t & nb1,
|
3053
|
+
constant ggml_metal_kargs_conv_transpose_1d & args,
|
2787
3054
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2788
3055
|
uint3 tgpg[[threadgroups_per_grid]]);
|
2789
3056
|
|
@@ -2792,38 +3059,14 @@ kernel void kernel_conv_transpose_1d<half>(
|
|
2792
3059
|
device const half * src0,
|
2793
3060
|
device const float * src1,
|
2794
3061
|
device char * dst,
|
2795
|
-
constant
|
2796
|
-
constant int32_t & IL,
|
2797
|
-
constant int32_t & K,
|
2798
|
-
constant int32_t & s0,
|
2799
|
-
constant uint64_t & nb0,
|
2800
|
-
constant uint64_t & nb1,
|
3062
|
+
constant ggml_metal_kargs_conv_transpose_1d & args,
|
2801
3063
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2802
3064
|
uint3 tgpg[[threadgroups_per_grid]]);
|
2803
3065
|
|
2804
3066
|
kernel void kernel_upscale_f32(
|
2805
3067
|
device const char * src0,
|
2806
3068
|
device char * dst,
|
2807
|
-
constant
|
2808
|
-
constant int64_t & ne01,
|
2809
|
-
constant int64_t & ne02,
|
2810
|
-
constant int64_t & ne03,
|
2811
|
-
constant uint64_t & nb00,
|
2812
|
-
constant uint64_t & nb01,
|
2813
|
-
constant uint64_t & nb02,
|
2814
|
-
constant uint64_t & nb03,
|
2815
|
-
constant int64_t & ne0,
|
2816
|
-
constant int64_t & ne1,
|
2817
|
-
constant int64_t & ne2,
|
2818
|
-
constant int64_t & ne3,
|
2819
|
-
constant uint64_t & nb0,
|
2820
|
-
constant uint64_t & nb1,
|
2821
|
-
constant uint64_t & nb2,
|
2822
|
-
constant uint64_t & nb3,
|
2823
|
-
constant float & sf0,
|
2824
|
-
constant float & sf1,
|
2825
|
-
constant float & sf2,
|
2826
|
-
constant float & sf3,
|
3069
|
+
constant ggml_metal_kargs_upscale & args,
|
2827
3070
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2828
3071
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
2829
3072
|
uint3 ntg[[threads_per_threadgroup]]) {
|
@@ -2832,15 +3075,15 @@ kernel void kernel_upscale_f32(
|
|
2832
3075
|
const int64_t i2 = tgpig.y;
|
2833
3076
|
const int64_t i1 = tgpig.x;
|
2834
3077
|
|
2835
|
-
const int64_t i03 = i3/sf3;
|
2836
|
-
const int64_t i02 = i2/sf2;
|
2837
|
-
const int64_t i01 = i1/sf1;
|
3078
|
+
const int64_t i03 = i3/args.sf3;
|
3079
|
+
const int64_t i02 = i2/args.sf2;
|
3080
|
+
const int64_t i01 = i1/args.sf1;
|
2838
3081
|
|
2839
|
-
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
2840
|
-
const int64_t i00 = i0/sf0;
|
3082
|
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
3083
|
+
const int64_t i00 = i0/args.sf0;
|
2841
3084
|
|
2842
|
-
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
2843
|
-
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
3085
|
+
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
3086
|
+
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
2844
3087
|
|
2845
3088
|
dst_ptr[0] = src0_ptr[0];
|
2846
3089
|
}
|
@@ -2849,22 +3092,7 @@ kernel void kernel_upscale_f32(
|
|
2849
3092
|
kernel void kernel_pad_f32(
|
2850
3093
|
device const char * src0,
|
2851
3094
|
device char * dst,
|
2852
|
-
constant
|
2853
|
-
constant int64_t & ne01,
|
2854
|
-
constant int64_t & ne02,
|
2855
|
-
constant int64_t & ne03,
|
2856
|
-
constant uint64_t & nb00,
|
2857
|
-
constant uint64_t & nb01,
|
2858
|
-
constant uint64_t & nb02,
|
2859
|
-
constant uint64_t & nb03,
|
2860
|
-
constant int64_t & ne0,
|
2861
|
-
constant int64_t & ne1,
|
2862
|
-
constant int64_t & ne2,
|
2863
|
-
constant int64_t & ne3,
|
2864
|
-
constant uint64_t & nb0,
|
2865
|
-
constant uint64_t & nb1,
|
2866
|
-
constant uint64_t & nb2,
|
2867
|
-
constant uint64_t & nb3,
|
3095
|
+
constant ggml_metal_kargs_pad & args,
|
2868
3096
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2869
3097
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
2870
3098
|
uint3 ntg[[threads_per_threadgroup]]) {
|
@@ -2877,12 +3105,12 @@ kernel void kernel_pad_f32(
|
|
2877
3105
|
const int64_t i02 = i2;
|
2878
3106
|
const int64_t i01 = i1;
|
2879
3107
|
|
2880
|
-
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
2881
|
-
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
|
3108
|
+
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
3109
|
+
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
|
2882
3110
|
|
2883
|
-
if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
2884
|
-
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
2885
|
-
if (i0 < ne00) {
|
3111
|
+
if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
|
3112
|
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
3113
|
+
if (i0 < args.ne00) {
|
2886
3114
|
dst_ptr[i0] = src0_ptr[i0];
|
2887
3115
|
} else {
|
2888
3116
|
dst_ptr[i0] = 0.0f;
|
@@ -2892,7 +3120,7 @@ kernel void kernel_pad_f32(
|
|
2892
3120
|
return;
|
2893
3121
|
}
|
2894
3122
|
|
2895
|
-
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
3123
|
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
2896
3124
|
dst_ptr[i0] = 0.0f;
|
2897
3125
|
}
|
2898
3126
|
}
|
@@ -2900,21 +3128,7 @@ kernel void kernel_pad_f32(
|
|
2900
3128
|
kernel void kernel_pad_reflect_1d_f32(
|
2901
3129
|
device const char * src0,
|
2902
3130
|
device char * dst,
|
2903
|
-
constant
|
2904
|
-
constant int64_t & ne01,
|
2905
|
-
constant int64_t & ne02,
|
2906
|
-
constant int64_t & ne03,
|
2907
|
-
constant int64_t & ne0,
|
2908
|
-
constant uint64_t & nb00,
|
2909
|
-
constant uint64_t & nb01,
|
2910
|
-
constant uint64_t & nb02,
|
2911
|
-
constant uint64_t & nb03,
|
2912
|
-
constant uint64_t & nb0,
|
2913
|
-
constant uint64_t & nb1,
|
2914
|
-
constant uint64_t & nb2,
|
2915
|
-
constant uint64_t & nb3,
|
2916
|
-
constant int32_t & p0,
|
2917
|
-
constant int32_t & p1,
|
3131
|
+
constant ggml_metal_kargs_pad_reflect_1d & args,
|
2918
3132
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2919
3133
|
uint3 tgpg[[threadgroups_per_grid]],
|
2920
3134
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
@@ -2928,17 +3142,17 @@ kernel void kernel_pad_reflect_1d_f32(
|
|
2928
3142
|
const int64_t i02 = i2;
|
2929
3143
|
const int64_t i01 = i1;
|
2930
3144
|
|
2931
|
-
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
2932
|
-
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
|
3145
|
+
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
3146
|
+
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
|
2933
3147
|
|
2934
|
-
if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
2935
|
-
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
2936
|
-
if (i0 < p0) {
|
2937
|
-
dst_ptr[i0] = src0_ptr[p0 - i0];
|
2938
|
-
} else if (i0 < ne0 - p1) {
|
2939
|
-
dst_ptr[i0] = src0_ptr[i0 - p0];
|
3148
|
+
if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
|
3149
|
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
3150
|
+
if (i0 < args.p0) {
|
3151
|
+
dst_ptr[i0] = src0_ptr[args.p0 - i0];
|
3152
|
+
} else if (i0 < args.ne0 - args.p1) {
|
3153
|
+
dst_ptr[i0] = src0_ptr[i0 - args.p0];
|
2940
3154
|
} else {
|
2941
|
-
dst_ptr[i0] = src0_ptr[(ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1];
|
3155
|
+
dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1];
|
2942
3156
|
}
|
2943
3157
|
}
|
2944
3158
|
}
|
@@ -2946,44 +3160,40 @@ kernel void kernel_pad_reflect_1d_f32(
|
|
2946
3160
|
|
2947
3161
|
kernel void kernel_arange_f32(
|
2948
3162
|
device char * dst,
|
2949
|
-
constant
|
2950
|
-
constant float & start,
|
2951
|
-
constant float & step,
|
3163
|
+
constant ggml_metal_kargs_arange & args,
|
2952
3164
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2953
3165
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
2954
3166
|
uint3 ntg[[threads_per_threadgroup]]) {
|
2955
3167
|
|
2956
3168
|
device float * dst_ptr = (device float *) dst;
|
2957
3169
|
|
2958
|
-
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
2959
|
-
dst_ptr[i0] = start + step * i0;
|
3170
|
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
3171
|
+
dst_ptr[i0] = args.start + args.step * i0;
|
2960
3172
|
}
|
2961
3173
|
}
|
2962
3174
|
|
2963
3175
|
kernel void kernel_timestep_embedding_f32(
|
2964
3176
|
device const char * src0,
|
2965
3177
|
device char * dst,
|
2966
|
-
constant
|
2967
|
-
constant int & dim,
|
2968
|
-
constant int & max_period,
|
3178
|
+
constant ggml_metal_kargs_timestep_embedding & args,
|
2969
3179
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2970
3180
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
2971
3181
|
uint3 ntg[[threads_per_threadgroup]]) {
|
2972
3182
|
|
2973
3183
|
int i = tgpig.x;
|
2974
|
-
device float * embed_data = (device float *)(dst +
|
3184
|
+
device float * embed_data = (device float *)(dst + i*args.nb1);
|
2975
3185
|
|
2976
|
-
int half_ = dim / 2;
|
3186
|
+
int half_ = args.dim / 2;
|
2977
3187
|
for (int j = tpitg.x; j < half_; j += ntg.x) {
|
2978
3188
|
float timestep = ((device float *)src0)[i];
|
2979
|
-
float freq = (float)exp(-log((float)max_period) * j / half_);
|
3189
|
+
float freq = (float)exp(-log((float)args.max_period) * j / half_);
|
2980
3190
|
float arg = timestep * freq;
|
2981
3191
|
embed_data[j ] = cos(arg);
|
2982
3192
|
embed_data[j + half_] = sin(arg);
|
2983
3193
|
}
|
2984
3194
|
|
2985
|
-
if (dim % 2 != 0 && tpitg.x == 0) {
|
2986
|
-
embed_data[dim] = 0.f;
|
3195
|
+
if (args.dim % 2 != 0 && tpitg.x == 0) {
|
3196
|
+
embed_data[args.dim] = 0.f;
|
2987
3197
|
}
|
2988
3198
|
}
|
2989
3199
|
|
@@ -2991,8 +3201,7 @@ kernel void kernel_timestep_embedding_f32(
|
|
2991
3201
|
typedef void (argsort_t)(
|
2992
3202
|
device const float * x,
|
2993
3203
|
device int32_t * dst,
|
2994
|
-
constant
|
2995
|
-
constant int64_t & ncols_pad,
|
3204
|
+
constant ggml_metal_kargs_argsort & args,
|
2996
3205
|
threadgroup int32_t * shared_values [[threadgroup(0)]],
|
2997
3206
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2998
3207
|
uint3 tpitg[[thread_position_in_threadgroup]]);
|
@@ -3001,8 +3210,7 @@ template<ggml_sort_order order>
|
|
3001
3210
|
kernel void kernel_argsort_f32_i32(
|
3002
3211
|
device const float * x,
|
3003
3212
|
device int32_t * dst,
|
3004
|
-
constant
|
3005
|
-
constant int64_t & ncols_pad,
|
3213
|
+
constant ggml_metal_kargs_argsort & args,
|
3006
3214
|
threadgroup int32_t * shared_values [[threadgroup(0)]],
|
3007
3215
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
3008
3216
|
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
@@ -3010,9 +3218,9 @@ kernel void kernel_argsort_f32_i32(
|
|
3010
3218
|
int col = tpitg[0];
|
3011
3219
|
int row = tgpig[1];
|
3012
3220
|
|
3013
|
-
if (col >= ncols_pad) return;
|
3221
|
+
if (col >= args.ncols_pad) return;
|
3014
3222
|
|
3015
|
-
device const float * x_row = x + row * ncols;
|
3223
|
+
device const float * x_row = x + row * args.ncols;
|
3016
3224
|
threadgroup int32_t * dst_row = shared_values;
|
3017
3225
|
|
3018
3226
|
// initialize indices
|
@@ -3020,21 +3228,21 @@ kernel void kernel_argsort_f32_i32(
|
|
3020
3228
|
|
3021
3229
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
3022
3230
|
|
3023
|
-
for (int k = 2; k <= ncols_pad; k *= 2) {
|
3231
|
+
for (int k = 2; k <= args.ncols_pad; k *= 2) {
|
3024
3232
|
for (int j = k / 2; j > 0; j /= 2) {
|
3025
3233
|
int ixj = col ^ j;
|
3026
3234
|
if (ixj > col) {
|
3027
3235
|
if ((col & k) == 0) {
|
3028
|
-
if (dst_row[col] >= ncols ||
|
3029
|
-
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
3236
|
+
if (dst_row[col] >= args.ncols ||
|
3237
|
+
(dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
|
3030
3238
|
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
|
3031
3239
|
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
|
3032
3240
|
) {
|
3033
3241
|
SWAP(dst_row[col], dst_row[ixj]);
|
3034
3242
|
}
|
3035
3243
|
} else {
|
3036
|
-
if (dst_row[ixj] >= ncols ||
|
3037
|
-
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
3244
|
+
if (dst_row[ixj] >= args.ncols ||
|
3245
|
+
(dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
|
3038
3246
|
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
|
3039
3247
|
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
|
3040
3248
|
) {
|
@@ -3047,8 +3255,8 @@ kernel void kernel_argsort_f32_i32(
|
|
3047
3255
|
}
|
3048
3256
|
|
3049
3257
|
// copy the result to dst without the padding
|
3050
|
-
if (col < ncols) {
|
3051
|
-
dst[row * ncols + col] = dst_row[col];
|
3258
|
+
if (col < args.ncols) {
|
3259
|
+
dst[row * args.ncols + col] = dst_row[col];
|
3052
3260
|
}
|
3053
3261
|
}
|
3054
3262
|
|
@@ -3058,9 +3266,9 @@ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_ar
|
|
3058
3266
|
kernel void kernel_leaky_relu_f32(
|
3059
3267
|
device const float * src0,
|
3060
3268
|
device float * dst,
|
3061
|
-
constant
|
3269
|
+
constant ggml_metal_kargs_leaky_relu & args,
|
3062
3270
|
uint tpig[[thread_position_in_grid]]) {
|
3063
|
-
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
|
3271
|
+
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * args.slope;
|
3064
3272
|
}
|
3065
3273
|
|
3066
3274
|
// ref: https://arxiv.org/pdf/2307.08691.pdf
|
@@ -3084,10 +3292,11 @@ template<
|
|
3084
3292
|
typename kd4x4_t, // key type in device memory
|
3085
3293
|
short nl_k,
|
3086
3294
|
void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
|
3087
|
-
typename vd4x4_t, //
|
3295
|
+
typename vd4x4_t, // value type in device memory
|
3088
3296
|
short nl_v,
|
3089
3297
|
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
|
3090
|
-
short
|
3298
|
+
short DK, // K head size
|
3299
|
+
short DV, // V head size
|
3091
3300
|
short Q = 8, // queries per threadgroup
|
3092
3301
|
short KV = 8, // key/value processed per each simdgroup
|
3093
3302
|
short C = 32> // cache items per threadgroup
|
@@ -3109,20 +3318,24 @@ kernel void kernel_flash_attn_ext(
|
|
3109
3318
|
const int iq2 = tgpig[1];
|
3110
3319
|
const int iq1 = tgpig[0]*Q;
|
3111
3320
|
|
3112
|
-
|
3113
|
-
|
3114
|
-
|
3115
|
-
|
3116
|
-
|
3321
|
+
constexpr short DK4 = DK/4;
|
3322
|
+
constexpr short DK8 = DK/8;
|
3323
|
+
constexpr short DK16 = DK/16;
|
3324
|
+
constexpr short DV4 = DV/4;
|
3325
|
+
constexpr short DV8 = DV/8;
|
3326
|
+
constexpr short DV16 = DV/16;
|
3327
|
+
|
3328
|
+
constexpr short NW = N_SIMDWIDTH;
|
3329
|
+
constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
|
3117
3330
|
|
3118
3331
|
const short TS = nsg*SH; // shared memory size per query in (s_t == float)
|
3119
|
-
const short T =
|
3332
|
+
const short T = DK + 2*TS; // shared memory size per query in (half)
|
3120
3333
|
|
3121
|
-
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*
|
3122
|
-
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*
|
3123
|
-
threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*
|
3124
|
-
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*
|
3125
|
-
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*
|
3334
|
+
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
|
3335
|
+
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
|
3336
|
+
threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation
|
3337
|
+
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
|
3338
|
+
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
|
3126
3339
|
|
3127
3340
|
threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
|
3128
3341
|
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
|
@@ -3131,23 +3344,23 @@ kernel void kernel_flash_attn_ext(
|
|
3131
3344
|
threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
|
3132
3345
|
|
3133
3346
|
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
3134
|
-
o8x8_t lo[
|
3347
|
+
o8x8_t lo[DV8];
|
3135
3348
|
|
3136
3349
|
// load heads from Q to shared memory
|
3137
3350
|
for (short j = sgitg; j < Q; j += nsg) {
|
3138
3351
|
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
|
3139
3352
|
|
3140
|
-
for (short i = tiisg; i <
|
3353
|
+
for (short i = tiisg; i < DK4; i += NW) {
|
3141
3354
|
if (iq1 + j < args.ne01) {
|
3142
|
-
sq4[j*
|
3355
|
+
sq4[j*DK4 + i] = (q4_t) q4[i];
|
3143
3356
|
} else {
|
3144
|
-
sq4[j*
|
3357
|
+
sq4[j*DK4 + i] = (q4_t) 0.0f;
|
3145
3358
|
}
|
3146
3359
|
}
|
3147
3360
|
}
|
3148
3361
|
|
3149
3362
|
// zero out lo
|
3150
|
-
for (short i = 0; i <
|
3363
|
+
for (short i = 0; i < DV8; ++i) {
|
3151
3364
|
lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
|
3152
3365
|
}
|
3153
3366
|
|
@@ -3161,8 +3374,8 @@ kernel void kernel_flash_attn_ext(
|
|
3161
3374
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
3162
3375
|
|
3163
3376
|
{
|
3164
|
-
|
3165
|
-
|
3377
|
+
float S[Q] = { [0 ... Q-1] = 0.0f };
|
3378
|
+
float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 };
|
3166
3379
|
|
3167
3380
|
// thread indices inside the simdgroup
|
3168
3381
|
// TODO: see if we can utilize quad-group functions for better performance
|
@@ -3177,22 +3390,15 @@ kernel void kernel_flash_attn_ext(
|
|
3177
3390
|
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
|
3178
3391
|
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
|
3179
3392
|
|
3180
|
-
// load the queries from shared memory into local memory
|
3181
|
-
q8x8_t mq[D8];
|
3182
|
-
|
3183
|
-
for (short i = 0; i < D8; ++i) {
|
3184
|
-
simdgroup_load(mq[i], sq + i*8, D);
|
3185
|
-
}
|
3186
|
-
|
3187
3393
|
const bool has_mask = mask != q;
|
3188
3394
|
|
3189
|
-
|
3395
|
+
float slope = 1.0f;
|
3190
3396
|
|
3191
3397
|
// ALiBi
|
3192
3398
|
if (args.max_bias > 0.0f) {
|
3193
3399
|
const short h = iq2;
|
3194
3400
|
|
3195
|
-
const
|
3401
|
+
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
3196
3402
|
const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
3197
3403
|
|
3198
3404
|
slope = pow(base, exph);
|
@@ -3208,14 +3414,14 @@ kernel void kernel_flash_attn_ext(
|
|
3208
3414
|
|
3209
3415
|
if (has_mask) {
|
3210
3416
|
// used to detect blocks full of -INF
|
3211
|
-
|
3417
|
+
float smax = -INFINITY;
|
3212
3418
|
|
3213
3419
|
// load the mask in shared memory
|
3214
3420
|
#pragma unroll(Q)
|
3215
3421
|
for (short j = 0; j < Q; ++j) {
|
3216
3422
|
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
|
3217
3423
|
|
3218
|
-
const
|
3424
|
+
const float m = pm[ic + tiisg];
|
3219
3425
|
|
3220
3426
|
ss[j*TS + C + tiisg] = m;
|
3221
3427
|
smax = max(smax, m);
|
@@ -3236,20 +3442,22 @@ kernel void kernel_flash_attn_ext(
|
|
3236
3442
|
// this is compile-time check, so it does not have runtime overhead
|
3237
3443
|
if (is_same<kd4x4_t, k4x4_t>::value) {
|
3238
3444
|
// we can read directly from global memory
|
3239
|
-
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.
|
3445
|
+
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
|
3240
3446
|
|
3241
|
-
#pragma unroll(
|
3242
|
-
for (short i = 0; i <
|
3447
|
+
#pragma unroll(DK8)
|
3448
|
+
for (short i = 0; i < DK8; ++i) {
|
3243
3449
|
k8x8_t mk;
|
3244
|
-
simdgroup_load(mk, pk + i*8, args.
|
3450
|
+
simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
|
3245
3451
|
|
3246
|
-
|
3452
|
+
q8x8_t mq;
|
3453
|
+
simdgroup_load(mq, sq + i*8, DK);
|
3454
|
+
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
|
3247
3455
|
}
|
3248
3456
|
} else {
|
3249
|
-
for (short ii = 0; ii <
|
3250
|
-
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.
|
3457
|
+
for (short ii = 0; ii < DK16; ii += 4) {
|
3458
|
+
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
|
3251
3459
|
|
3252
|
-
if (
|
3460
|
+
if (DK16%4 == 0) {
|
3253
3461
|
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
|
3254
3462
|
{
|
3255
3463
|
k4x4_t tmp;
|
@@ -3262,15 +3470,18 @@ kernel void kernel_flash_attn_ext(
|
|
3262
3470
|
#pragma unroll(4)
|
3263
3471
|
for (short k = 0; k < 4; ++k) {
|
3264
3472
|
k8x8_t mk;
|
3473
|
+
q8x8_t mq;
|
3265
3474
|
|
3266
3475
|
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
|
3267
|
-
|
3476
|
+
simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
|
3477
|
+
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
|
3268
3478
|
|
3269
3479
|
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
|
3270
|
-
|
3480
|
+
simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
|
3481
|
+
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
|
3271
3482
|
}
|
3272
3483
|
} else {
|
3273
|
-
if (ii + tx <
|
3484
|
+
if (ii + tx < DK16) {
|
3274
3485
|
k4x4_t tmp;
|
3275
3486
|
deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
|
3276
3487
|
sk4x4[4*ty + tx] = tmp;
|
@@ -3278,14 +3489,17 @@ kernel void kernel_flash_attn_ext(
|
|
3278
3489
|
|
3279
3490
|
simdgroup_barrier(mem_flags::mem_threadgroup);
|
3280
3491
|
|
3281
|
-
for (short k = 0; k < 4 && ii + k <
|
3492
|
+
for (short k = 0; k < 4 && ii + k < DK16; ++k) {
|
3282
3493
|
k8x8_t mk;
|
3494
|
+
q8x8_t mq;
|
3283
3495
|
|
3284
3496
|
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
|
3285
|
-
|
3497
|
+
simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
|
3498
|
+
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
|
3286
3499
|
|
3287
3500
|
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
|
3288
|
-
|
3501
|
+
simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
|
3502
|
+
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
|
3289
3503
|
}
|
3290
3504
|
}
|
3291
3505
|
}
|
@@ -3303,10 +3517,10 @@ kernel void kernel_flash_attn_ext(
|
|
3303
3517
|
// online softmax
|
3304
3518
|
{
|
3305
3519
|
for (ushort j = 0; j < Q; ++j) {
|
3306
|
-
const
|
3520
|
+
const float m = M[j];
|
3307
3521
|
|
3308
3522
|
// scale and apply the logitcap / mask
|
3309
|
-
|
3523
|
+
float s = ss[j*TS + tiisg]*args.scale;
|
3310
3524
|
|
3311
3525
|
if (args.logit_softcap != 0.0f) {
|
3312
3526
|
s = args.logit_softcap*precise::tanh(s);
|
@@ -3317,8 +3531,8 @@ kernel void kernel_flash_attn_ext(
|
|
3317
3531
|
|
3318
3532
|
M[j] = simd_max(max(M[j], s));
|
3319
3533
|
|
3320
|
-
const
|
3321
|
-
const
|
3534
|
+
const float ms = exp(m - M[j]);
|
3535
|
+
const float vs = exp(s - M[j]);
|
3322
3536
|
|
3323
3537
|
S[j] = S[j]*ms + simd_sum(vs);
|
3324
3538
|
|
@@ -3337,8 +3551,8 @@ kernel void kernel_flash_attn_ext(
|
|
3337
3551
|
s8x8_t mm;
|
3338
3552
|
simdgroup_load(mm, ss + 2*C, TS, 0, false);
|
3339
3553
|
|
3340
|
-
#pragma unroll(
|
3341
|
-
for (short i = 0; i <
|
3554
|
+
#pragma unroll(DV8)
|
3555
|
+
for (short i = 0; i < DV8; ++i) {
|
3342
3556
|
simdgroup_multiply(lo[i], mm, lo[i]);
|
3343
3557
|
}
|
3344
3558
|
}
|
@@ -3351,20 +3565,20 @@ kernel void kernel_flash_attn_ext(
|
|
3351
3565
|
|
3352
3566
|
if (is_same<vd4x4_t, v4x4_t>::value) {
|
3353
3567
|
// we can read directly from global memory
|
3354
|
-
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.
|
3568
|
+
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
|
3355
3569
|
|
3356
|
-
#pragma unroll(
|
3357
|
-
for (short i = 0; i <
|
3570
|
+
#pragma unroll(DV8)
|
3571
|
+
for (short i = 0; i < DV8; ++i) {
|
3358
3572
|
v8x8_t mv;
|
3359
|
-
simdgroup_load(mv, pv + i*8, args.
|
3573
|
+
simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
|
3360
3574
|
|
3361
3575
|
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
|
3362
3576
|
}
|
3363
3577
|
} else {
|
3364
|
-
for (short ii = 0; ii <
|
3365
|
-
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.
|
3578
|
+
for (short ii = 0; ii < DV16; ii += 4) {
|
3579
|
+
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
|
3366
3580
|
|
3367
|
-
if (
|
3581
|
+
if (DV16%4 == 0) {
|
3368
3582
|
// no need for bound checks
|
3369
3583
|
{
|
3370
3584
|
v4x4_t tmp;
|
@@ -3385,7 +3599,7 @@ kernel void kernel_flash_attn_ext(
|
|
3385
3599
|
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
|
3386
3600
|
}
|
3387
3601
|
} else {
|
3388
|
-
if (ii + tx <
|
3602
|
+
if (ii + tx < DV16) {
|
3389
3603
|
v4x4_t tmp;
|
3390
3604
|
deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
|
3391
3605
|
sv4x4[4*ty + tx] = tmp;
|
@@ -3393,7 +3607,7 @@ kernel void kernel_flash_attn_ext(
|
|
3393
3607
|
|
3394
3608
|
simdgroup_barrier(mem_flags::mem_threadgroup);
|
3395
3609
|
|
3396
|
-
for (short k = 0; k < 4 && ii + k <
|
3610
|
+
for (short k = 0; k < 4 && ii + k < DV16; ++k) {
|
3397
3611
|
v8x8_t mv;
|
3398
3612
|
|
3399
3613
|
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
|
@@ -3420,15 +3634,15 @@ kernel void kernel_flash_attn_ext(
|
|
3420
3634
|
|
3421
3635
|
// reduce the warps sequentially
|
3422
3636
|
for (ushort sg = 1; sg < nsg; ++sg) {
|
3423
|
-
|
3424
|
-
|
3637
|
+
float S = { 0.0f };
|
3638
|
+
float M = { -__FLT_MAX__/2 };
|
3425
3639
|
|
3426
3640
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
3427
3641
|
|
3428
3642
|
// each simdgroup stores its output to shared memory, reusing sq
|
3429
3643
|
if (sgitg == sg) {
|
3430
|
-
for (short i = 0; i <
|
3431
|
-
simdgroup_store(lo[i], so + i*8,
|
3644
|
+
for (short i = 0; i < DV8; ++i) {
|
3645
|
+
simdgroup_store(lo[i], so + i*8, DV, 0, false);
|
3432
3646
|
}
|
3433
3647
|
}
|
3434
3648
|
|
@@ -3437,16 +3651,16 @@ kernel void kernel_flash_attn_ext(
|
|
3437
3651
|
// the first simdgroup accumulates the results from the other simdgroups
|
3438
3652
|
if (sgitg == 0) {
|
3439
3653
|
for (short j = 0; j < Q; ++j) {
|
3440
|
-
const
|
3441
|
-
const
|
3654
|
+
const float S0 = ss[j*TS + 0];
|
3655
|
+
const float S1 = ss[j*TS + sg*SH + 0];
|
3442
3656
|
|
3443
|
-
const
|
3444
|
-
const
|
3657
|
+
const float M0 = ss[j*TS + 1];
|
3658
|
+
const float M1 = ss[j*TS + sg*SH + 1];
|
3445
3659
|
|
3446
3660
|
M = max(M0, M1);
|
3447
3661
|
|
3448
|
-
const
|
3449
|
-
const
|
3662
|
+
const float ms0 = exp(M0 - M);
|
3663
|
+
const float ms1 = exp(M1 - M);
|
3450
3664
|
|
3451
3665
|
S = S0*ms0 + S1*ms1;
|
3452
3666
|
|
@@ -3467,11 +3681,11 @@ kernel void kernel_flash_attn_ext(
|
|
3467
3681
|
simdgroup_load(ms0, ss + 2*C, TS, 0, false);
|
3468
3682
|
simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
|
3469
3683
|
|
3470
|
-
#pragma unroll(
|
3471
|
-
for (short i = 0; i <
|
3684
|
+
#pragma unroll(DV8)
|
3685
|
+
for (short i = 0; i < DV8; ++i) {
|
3472
3686
|
o8x8_t t;
|
3473
3687
|
|
3474
|
-
simdgroup_load (t, so + i*8,
|
3688
|
+
simdgroup_load (t, so + i*8, DV, 0, false);
|
3475
3689
|
simdgroup_multiply(t, ms1, t);
|
3476
3690
|
|
3477
3691
|
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
|
@@ -3482,8 +3696,8 @@ kernel void kernel_flash_attn_ext(
|
|
3482
3696
|
|
3483
3697
|
// store result to shared memory (reuse sq)
|
3484
3698
|
if (sgitg == 0) {
|
3485
|
-
for (short i = 0; i <
|
3486
|
-
simdgroup_store(lo[i], so + i*8,
|
3699
|
+
for (short i = 0; i < DV8; ++i) {
|
3700
|
+
simdgroup_store(lo[i], so + i*8, DV, 0, false);
|
3487
3701
|
}
|
3488
3702
|
}
|
3489
3703
|
|
@@ -3494,8 +3708,8 @@ kernel void kernel_flash_attn_ext(
|
|
3494
3708
|
for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
|
3495
3709
|
const float S = ss[j*TS + 0];
|
3496
3710
|
|
3497
|
-
for (short i = tiisg; i <
|
3498
|
-
dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*
|
3711
|
+
for (short i = tiisg; i < DV4; i += NW) {
|
3712
|
+
dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
|
3499
3713
|
}
|
3500
3714
|
}
|
3501
3715
|
}
|
@@ -3512,80 +3726,101 @@ kernel void kernel_flash_attn_ext(
|
|
3512
3726
|
float, simdgroup_float8x8, \
|
3513
3727
|
half, half4, simdgroup_half8x8
|
3514
3728
|
|
3515
|
-
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
|
3729
|
+
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
|
3516
3730
|
|
3517
|
-
template [[host_name("kernel_flash_attn_ext_f16_h64" )]]
|
3518
|
-
template [[host_name("kernel_flash_attn_ext_f16_h80" )]]
|
3519
|
-
template [[host_name("kernel_flash_attn_ext_f16_h96" )]]
|
3520
|
-
template [[host_name("kernel_flash_attn_ext_f16_h112")]]
|
3521
|
-
template [[host_name("kernel_flash_attn_ext_f16_h128")]]
|
3522
|
-
template [[host_name("
|
3731
|
+
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
|
3732
|
+
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
|
3733
|
+
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
|
3734
|
+
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
|
3735
|
+
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128, 128>;
|
3736
|
+
template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
|
3737
|
+
template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
|
3738
|
+
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
|
3739
|
+
template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
|
3523
3740
|
|
3524
3741
|
#if defined(GGML_METAL_USE_BF16)
|
3525
|
-
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]]
|
3526
|
-
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]]
|
3527
|
-
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]]
|
3528
|
-
template [[host_name("kernel_flash_attn_ext_bf16_h112")]]
|
3529
|
-
template [[host_name("kernel_flash_attn_ext_bf16_h128")]]
|
3530
|
-
template [[host_name("
|
3742
|
+
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
|
3743
|
+
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
|
3744
|
+
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
|
3745
|
+
template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
|
3746
|
+
template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
|
3747
|
+
template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
|
3748
|
+
template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
|
3749
|
+
template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
|
3750
|
+
template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
|
3531
3751
|
#endif
|
3532
3752
|
|
3533
|
-
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]]
|
3534
|
-
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]]
|
3535
|
-
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]]
|
3536
|
-
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]]
|
3537
|
-
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]]
|
3538
|
-
template [[host_name("
|
3539
|
-
|
3540
|
-
template [[host_name("
|
3541
|
-
template [[host_name("
|
3542
|
-
|
3543
|
-
template [[host_name("
|
3544
|
-
template [[host_name("
|
3545
|
-
template [[host_name("
|
3546
|
-
|
3547
|
-
template [[host_name("
|
3548
|
-
template [[host_name("
|
3549
|
-
template [[host_name("
|
3550
|
-
template [[host_name("
|
3551
|
-
template [[host_name("
|
3552
|
-
|
3553
|
-
|
3554
|
-
template [[host_name("
|
3555
|
-
template [[host_name("
|
3556
|
-
template [[host_name("
|
3557
|
-
template [[host_name("
|
3558
|
-
template [[host_name("
|
3559
|
-
template [[host_name("
|
3560
|
-
|
3561
|
-
template [[host_name("
|
3562
|
-
|
3563
|
-
template [[host_name("
|
3564
|
-
template [[host_name("
|
3565
|
-
template [[host_name("
|
3566
|
-
template [[host_name("
|
3753
|
+
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
|
3754
|
+
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
|
3755
|
+
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
|
3756
|
+
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
|
3757
|
+
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128, 128>;
|
3758
|
+
template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
|
3759
|
+
template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
|
3760
|
+
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
|
3761
|
+
template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
|
3762
|
+
|
3763
|
+
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
|
3764
|
+
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
|
3765
|
+
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
|
3766
|
+
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
|
3767
|
+
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128, 128>;
|
3768
|
+
template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
|
3769
|
+
template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
|
3770
|
+
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
|
3771
|
+
template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
|
3772
|
+
|
3773
|
+
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
|
3774
|
+
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
|
3775
|
+
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
|
3776
|
+
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
|
3777
|
+
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128, 128>;
|
3778
|
+
template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
|
3779
|
+
template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
|
3780
|
+
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
|
3781
|
+
template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
|
3782
|
+
|
3783
|
+
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
|
3784
|
+
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
|
3785
|
+
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
|
3786
|
+
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
|
3787
|
+
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128, 128>;
|
3788
|
+
template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
|
3789
|
+
template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
|
3790
|
+
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
|
3791
|
+
template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
|
3792
|
+
|
3793
|
+
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
|
3794
|
+
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
|
3795
|
+
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
|
3796
|
+
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
|
3797
|
+
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128, 128>;
|
3798
|
+
template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
|
3799
|
+
template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
|
3800
|
+
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
|
3801
|
+
template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
|
3567
3802
|
|
3568
3803
|
#undef FA_TYPES
|
3569
3804
|
|
3570
3805
|
template<
|
3571
|
-
typename q4_t,
|
3572
|
-
typename
|
3573
|
-
typename
|
3574
|
-
typename
|
3575
|
-
typename
|
3576
|
-
typename s_t, // soft-max types
|
3806
|
+
typename q4_t, // query types in shared memory
|
3807
|
+
typename k4_t, // key types in shared memory
|
3808
|
+
typename v4_t, // value types in shared memory
|
3809
|
+
typename qk_t, // Q*K types
|
3810
|
+
typename s_t, // soft-max types
|
3577
3811
|
typename s4_t,
|
3578
|
-
typename
|
3579
|
-
typename
|
3580
|
-
typename kd4x4_t, // key type in device memory
|
3812
|
+
typename o4_t, // attention accumulation types
|
3813
|
+
typename kd4_t, // key type in device memory
|
3581
3814
|
short nl_k,
|
3582
|
-
void (*
|
3583
|
-
typename
|
3815
|
+
void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
|
3816
|
+
typename vd4_t, // value type in device memory
|
3584
3817
|
short nl_v,
|
3585
|
-
void (*
|
3586
|
-
short
|
3587
|
-
short
|
3588
|
-
short
|
3818
|
+
void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
|
3819
|
+
short DK, // K head size
|
3820
|
+
short DV, // V head size
|
3821
|
+
short NE = 4, // head elements per thread
|
3822
|
+
short Q = 1, // queries per threadgroup
|
3823
|
+
short C = 32> // cache items per threadgroup
|
3589
3824
|
kernel void kernel_flash_attn_ext_vec(
|
3590
3825
|
constant ggml_metal_kargs_flash_attn_ext & args,
|
3591
3826
|
device const char * q,
|
@@ -3604,29 +3839,28 @@ kernel void kernel_flash_attn_ext_vec(
|
|
3604
3839
|
const int iq2 = tgpig[1];
|
3605
3840
|
const int iq1 = tgpig[0];
|
3606
3841
|
|
3607
|
-
|
3608
|
-
|
3609
|
-
|
3610
|
-
|
3611
|
-
|
3842
|
+
constexpr short DK4 = DK/4;
|
3843
|
+
constexpr short DV4 = DV/4;
|
3844
|
+
constexpr short NW = N_SIMDWIDTH;
|
3845
|
+
constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
|
3846
|
+
constexpr short SH = 4*C; // shared memory per simdgroup
|
3612
3847
|
|
3613
|
-
const short T =
|
3848
|
+
const short T = DK + nsg*SH; // shared memory size per query in (half)
|
3614
3849
|
|
3615
|
-
//threadgroup q_t
|
3616
|
-
threadgroup q4_t
|
3617
|
-
threadgroup
|
3618
|
-
threadgroup
|
3619
|
-
threadgroup
|
3620
|
-
threadgroup
|
3621
|
-
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D + Q*T); // scratch buffer for the results
|
3850
|
+
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
|
3851
|
+
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
|
3852
|
+
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
|
3853
|
+
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
|
3854
|
+
threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
|
3855
|
+
threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
|
3622
3856
|
|
3623
|
-
// store the result for all queries in local memory
|
3624
|
-
|
3857
|
+
// store the result for all queries in local memory (the O matrix from the paper)
|
3858
|
+
o4_t lo[DV4/NL];
|
3625
3859
|
|
3626
3860
|
// load heads from Q to shared memory
|
3627
3861
|
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
|
3628
3862
|
|
3629
|
-
for (short i = tiisg; i <
|
3863
|
+
for (short i = tiisg; i < DK4; i += NW) {
|
3630
3864
|
if (iq1 < args.ne01) {
|
3631
3865
|
sq4[i] = (q4_t) q4[i];
|
3632
3866
|
} else {
|
@@ -3635,8 +3869,8 @@ kernel void kernel_flash_attn_ext_vec(
|
|
3635
3869
|
}
|
3636
3870
|
|
3637
3871
|
// zero out lo
|
3638
|
-
for (short i = 0; i <
|
3639
|
-
lo[i] = (
|
3872
|
+
for (short i = 0; i < DV4/NL; ++i) {
|
3873
|
+
lo[i] = (o4_t) 0.0f;
|
3640
3874
|
}
|
3641
3875
|
|
3642
3876
|
// zero out shared memory SH
|
@@ -3647,8 +3881,8 @@ kernel void kernel_flash_attn_ext_vec(
|
|
3647
3881
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
3648
3882
|
|
3649
3883
|
{
|
3650
|
-
|
3651
|
-
|
3884
|
+
float S = 0.0f;
|
3885
|
+
float M = -__FLT_MAX__/2;
|
3652
3886
|
|
3653
3887
|
// thread indices inside the simdgroup
|
3654
3888
|
const short tx = tiisg%NL;
|
@@ -3661,26 +3895,18 @@ kernel void kernel_flash_attn_ext_vec(
|
|
3661
3895
|
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
|
3662
3896
|
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
|
3663
3897
|
|
3664
|
-
|
3665
|
-
q4x4_t mq[D16/NL];
|
3666
|
-
|
3667
|
-
#pragma unroll(D16/NL)
|
3668
|
-
for (short ii = 0; ii < D16; ii += NL) {
|
3669
|
-
mq[ii/NL] = sq4x4[ii + tx];
|
3670
|
-
}
|
3671
|
-
|
3672
|
-
const bool has_mask = mask != q;
|
3898
|
+
const bool has_mask = mask != q;
|
3673
3899
|
|
3674
3900
|
// pointer to the mask
|
3675
3901
|
device const half * pm = (device const half *) (mask + iq1*args.nb31);
|
3676
3902
|
|
3677
|
-
|
3903
|
+
float slope = 1.0f;
|
3678
3904
|
|
3679
3905
|
// ALiBi
|
3680
3906
|
if (args.max_bias > 0.0f) {
|
3681
3907
|
const short h = iq2;
|
3682
3908
|
|
3683
|
-
const
|
3909
|
+
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
3684
3910
|
const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
3685
3911
|
|
3686
3912
|
slope = pow(base, exph);
|
@@ -3698,45 +3924,63 @@ kernel void kernel_flash_attn_ext_vec(
|
|
3698
3924
|
sm[tiisg] = pm[ic + tiisg];
|
3699
3925
|
}
|
3700
3926
|
|
3927
|
+
// skip -INF blocks
|
3928
|
+
if (simd_max(sm[tiisg]) == -INFINITY) {
|
3929
|
+
continue;
|
3930
|
+
}
|
3931
|
+
|
3701
3932
|
// Q*K^T
|
3702
3933
|
{
|
3703
|
-
// each simdgroup processes 1 query and
|
3704
|
-
for (short cc = 0; cc < C/
|
3705
|
-
qk_t
|
3934
|
+
// each simdgroup processes 1 query and NE (NW/NL) head elements
|
3935
|
+
for (short cc = 0; cc < C/NE; ++cc) {
|
3936
|
+
qk_t mqk = 0.0f;
|
3706
3937
|
|
3707
|
-
device const
|
3938
|
+
device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
|
3708
3939
|
|
3709
|
-
#pragma unroll(
|
3710
|
-
for (short ii = 0; ii <
|
3940
|
+
#pragma unroll(DK4/NL)
|
3941
|
+
for (short ii = 0; ii < DK4; ii += NL) {
|
3711
3942
|
const short i = ii + tx;
|
3712
3943
|
|
3713
|
-
|
3714
|
-
|
3944
|
+
k4_t mk;
|
3945
|
+
deq_k_t4(pk + i/nl_k, i%nl_k, mk);
|
3715
3946
|
|
3716
3947
|
// note: this is less precise than the version below
|
3717
|
-
//mqka[0] += dot(mq[
|
3718
|
-
//mqka[1] += dot(mq[
|
3719
|
-
//mqka[2] += dot(mq[
|
3720
|
-
//mqka[3] += dot(mq[
|
3721
|
-
|
3722
|
-
|
3723
|
-
mqka[
|
3724
|
-
mqka[
|
3725
|
-
mqka[
|
3948
|
+
//mqka[0] += dot(mq[0], mk[0]);
|
3949
|
+
//mqka[1] += dot(mq[1], mk[1]);
|
3950
|
+
//mqka[2] += dot(mq[2], mk[2]);
|
3951
|
+
//mqka[3] += dot(mq[3], mk[3]);
|
3952
|
+
|
3953
|
+
//q4x4_t mq = sq4x4[i];
|
3954
|
+
//mqka[0] += dot((float4) mq[0], (float4) mk[0]);
|
3955
|
+
//mqka[1] += dot((float4) mq[1], (float4) mk[1]);
|
3956
|
+
//mqka[2] += dot((float4) mq[2], (float4) mk[2]);
|
3957
|
+
//mqka[3] += dot((float4) mq[3], (float4) mk[3]);
|
3958
|
+
|
3959
|
+
mqk += dot((float4) mk, (float4) sq4[i]);
|
3726
3960
|
}
|
3727
3961
|
|
3728
|
-
|
3962
|
+
static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails
|
3729
3963
|
|
3730
|
-
// simdgroup reduce
|
3964
|
+
// simdgroup reduce (NE = 4)
|
3731
3965
|
// [ 0 .. 7] -> [ 0]
|
3732
3966
|
// [ 8 .. 15] -> [ 8]
|
3733
3967
|
// [16 .. 23] -> [16]
|
3734
3968
|
// [24 .. 31] -> [24]
|
3735
|
-
|
3736
|
-
|
3737
|
-
|
3738
|
-
|
3739
|
-
|
3969
|
+
if (NE <= 1) {
|
3970
|
+
mqk += simd_shuffle_down(mqk, 16);
|
3971
|
+
}
|
3972
|
+
if (NE <= 2) {
|
3973
|
+
mqk += simd_shuffle_down(mqk, 8);
|
3974
|
+
}
|
3975
|
+
if (NE <= 4) {
|
3976
|
+
mqk += simd_shuffle_down(mqk, 4);
|
3977
|
+
}
|
3978
|
+
if (NE <= 8) {
|
3979
|
+
mqk += simd_shuffle_down(mqk, 2);
|
3980
|
+
}
|
3981
|
+
if (NE <= 16) {
|
3982
|
+
mqk += simd_shuffle_down(mqk, 1);
|
3983
|
+
}
|
3740
3984
|
|
3741
3985
|
// mqk = mqk*scale + mask*slope
|
3742
3986
|
if (tx == 0) {
|
@@ -3746,9 +3990,9 @@ kernel void kernel_flash_attn_ext_vec(
|
|
3746
3990
|
mqk = args.logit_softcap*precise::tanh(mqk);
|
3747
3991
|
}
|
3748
3992
|
|
3749
|
-
mqk += sm[
|
3993
|
+
mqk += sm[NE*cc + ty]*slope;
|
3750
3994
|
|
3751
|
-
ss[
|
3995
|
+
ss[NE*cc + ty] = mqk;
|
3752
3996
|
}
|
3753
3997
|
}
|
3754
3998
|
}
|
@@ -3757,13 +4001,13 @@ kernel void kernel_flash_attn_ext_vec(
|
|
3757
4001
|
|
3758
4002
|
// online softmax
|
3759
4003
|
{
|
3760
|
-
const
|
3761
|
-
const
|
4004
|
+
const float m = M;
|
4005
|
+
const float s = ss[tiisg];
|
3762
4006
|
|
3763
4007
|
M = simd_max(max(M, s));
|
3764
4008
|
|
3765
|
-
const
|
3766
|
-
const
|
4009
|
+
const float ms = exp(m - M);
|
4010
|
+
const float vs = exp(s - M);
|
3767
4011
|
|
3768
4012
|
S = S*ms + simd_sum(vs);
|
3769
4013
|
|
@@ -3771,8 +4015,8 @@ kernel void kernel_flash_attn_ext_vec(
|
|
3771
4015
|
ss[tiisg] = vs;
|
3772
4016
|
|
3773
4017
|
// O = diag(ms)*O
|
3774
|
-
#pragma unroll(
|
3775
|
-
for (short ii = 0; ii <
|
4018
|
+
#pragma unroll(DV4/NL)
|
4019
|
+
for (short ii = 0; ii < DV4; ii += NL) {
|
3776
4020
|
lo[ii/NL] *= ms;
|
3777
4021
|
}
|
3778
4022
|
}
|
@@ -3781,19 +4025,20 @@ kernel void kernel_flash_attn_ext_vec(
|
|
3781
4025
|
|
3782
4026
|
// O = O + (Q*K^T)*V
|
3783
4027
|
{
|
3784
|
-
|
3785
|
-
|
4028
|
+
//#pragma unroll(C/NE)
|
4029
|
+
for (short cc = 0; cc < C/NE; ++cc) {
|
4030
|
+
device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
|
3786
4031
|
|
3787
|
-
const
|
4032
|
+
const s4_t ms(ss[NE*cc + ty]);
|
3788
4033
|
|
3789
|
-
#pragma unroll(
|
3790
|
-
for (short ii = 0; ii <
|
4034
|
+
#pragma unroll(DV4/NL)
|
4035
|
+
for (short ii = 0; ii < DV4; ii += NL) {
|
3791
4036
|
const short i = ii + tx;
|
3792
4037
|
|
3793
|
-
|
3794
|
-
|
4038
|
+
v4_t mv;
|
4039
|
+
deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
|
3795
4040
|
|
3796
|
-
lo[ii/NL] += mv*ms;
|
4041
|
+
lo[ii/NL] += o4_t(float4(mv)*float4(ms));
|
3797
4042
|
}
|
3798
4043
|
}
|
3799
4044
|
}
|
@@ -3806,7 +4051,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|
3806
4051
|
}
|
3807
4052
|
}
|
3808
4053
|
|
3809
|
-
// simdgroup reduce
|
4054
|
+
// simdgroup reduce (NE = 4)
|
3810
4055
|
// [ 0, 8, 16, 24] -> [ 0]
|
3811
4056
|
// [ 1, 9, 17, 25] -> [ 1]
|
3812
4057
|
// [ 2, 10, 18, 26] -> [ 2]
|
@@ -3815,37 +4060,48 @@ kernel void kernel_flash_attn_ext_vec(
|
|
3815
4060
|
// [ 5, 13, 21, 29] -> [ 5]
|
3816
4061
|
// [ 6, 14, 22, 30] -> [ 6]
|
3817
4062
|
// [ 7, 15, 23, 31] -> [ 7]
|
3818
|
-
for (short ii = 0; ii <
|
3819
|
-
|
3820
|
-
|
3821
|
-
|
3822
|
-
|
3823
|
-
|
3824
|
-
|
3825
|
-
|
3826
|
-
|
3827
|
-
|
3828
|
-
|
3829
|
-
|
3830
|
-
|
3831
|
-
|
3832
|
-
|
3833
|
-
|
3834
|
-
|
3835
|
-
|
3836
|
-
|
3837
|
-
|
3838
|
-
|
3839
|
-
|
3840
|
-
|
3841
|
-
|
4063
|
+
for (short ii = 0; ii < DV4; ii += NL) {
|
4064
|
+
if (NE > 1) {
|
4065
|
+
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
|
4066
|
+
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
|
4067
|
+
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
|
4068
|
+
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
|
4069
|
+
}
|
4070
|
+
|
4071
|
+
if (NE > 2) {
|
4072
|
+
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8);
|
4073
|
+
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8);
|
4074
|
+
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8);
|
4075
|
+
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8);
|
4076
|
+
}
|
4077
|
+
|
4078
|
+
if (NE > 4) {
|
4079
|
+
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4);
|
4080
|
+
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4);
|
4081
|
+
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4);
|
4082
|
+
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4);
|
4083
|
+
}
|
4084
|
+
|
4085
|
+
if (NE > 8) {
|
4086
|
+
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2);
|
4087
|
+
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2);
|
4088
|
+
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2);
|
4089
|
+
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2);
|
4090
|
+
}
|
4091
|
+
|
4092
|
+
if (NE > 16) {
|
4093
|
+
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1);
|
4094
|
+
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1);
|
4095
|
+
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1);
|
4096
|
+
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1);
|
4097
|
+
}
|
3842
4098
|
}
|
3843
4099
|
|
3844
4100
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
3845
4101
|
|
3846
4102
|
// store results to shared memory
|
3847
|
-
for (short i = tiisg; i <
|
3848
|
-
|
4103
|
+
for (short i = tiisg; i < DV4; i += NL) {
|
4104
|
+
sr4[i] = lo[i/NL];
|
3849
4105
|
}
|
3850
4106
|
|
3851
4107
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
@@ -3853,18 +4109,18 @@ kernel void kernel_flash_attn_ext_vec(
|
|
3853
4109
|
// parallel reduce
|
3854
4110
|
for (short r = nsg/2; r > 0; r >>= 1) {
|
3855
4111
|
if (sgitg < r) {
|
3856
|
-
const
|
3857
|
-
const
|
4112
|
+
const float S0 = ss[ 0];
|
4113
|
+
const float S1 = ss[r*(SH/2) + 0];
|
3858
4114
|
|
3859
|
-
const
|
3860
|
-
const
|
4115
|
+
const float M0 = ss[ 1];
|
4116
|
+
const float M1 = ss[r*(SH/2) + 1];
|
3861
4117
|
|
3862
|
-
const
|
4118
|
+
const float M = max(M0, M1);
|
3863
4119
|
|
3864
|
-
const
|
3865
|
-
const
|
4120
|
+
const float ms0 = exp(M0 - M);
|
4121
|
+
const float ms1 = exp(M1 - M);
|
3866
4122
|
|
3867
|
-
const
|
4123
|
+
const float S = S0*ms0 + S1*ms1;
|
3868
4124
|
|
3869
4125
|
if (tiisg == 0) {
|
3870
4126
|
ss[0] = S;
|
@@ -3872,22 +4128,22 @@ kernel void kernel_flash_attn_ext_vec(
|
|
3872
4128
|
}
|
3873
4129
|
|
3874
4130
|
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
3875
|
-
for (short i = tiisg; i <
|
3876
|
-
|
4131
|
+
for (short i = tiisg; i < DV4; i += NW) {
|
4132
|
+
sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1;
|
3877
4133
|
}
|
3878
4134
|
}
|
3879
4135
|
|
3880
4136
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
3881
4137
|
}
|
3882
4138
|
|
3883
|
-
device
|
4139
|
+
device float4 * dst4 = (device float4 *) dst;
|
3884
4140
|
|
3885
4141
|
// final rescale with 1/S and store to global memory
|
3886
4142
|
if (sgitg == 0) {
|
3887
4143
|
const float S = ss[0];
|
3888
4144
|
|
3889
|
-
for (short i = tiisg; i <
|
3890
|
-
|
4145
|
+
for (short i = tiisg; i < DV4; i += NW) {
|
4146
|
+
dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S;
|
3891
4147
|
}
|
3892
4148
|
}
|
3893
4149
|
}
|
@@ -3896,34 +4152,84 @@ kernel void kernel_flash_attn_ext_vec(
|
|
3896
4152
|
// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
|
3897
4153
|
//
|
3898
4154
|
#define FA_TYPES \
|
3899
|
-
half4,
|
3900
|
-
|
3901
|
-
|
3902
|
-
float,
|
3903
|
-
|
3904
|
-
|
4155
|
+
half4, \
|
4156
|
+
half4, \
|
4157
|
+
half4, \
|
4158
|
+
float, \
|
4159
|
+
float, float4, \
|
4160
|
+
half4
|
4161
|
+
|
4162
|
+
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
|
4163
|
+
|
4164
|
+
template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 8>;
|
4165
|
+
#if defined(GGML_METAL_USE_BF16)
|
4166
|
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 8>;
|
4167
|
+
#endif
|
4168
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 8>;
|
4169
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 8>;
|
4170
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 8>;
|
4171
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 8>;
|
4172
|
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 8>;
|
4173
|
+
|
4174
|
+
template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
|
4175
|
+
#if defined(GGML_METAL_USE_BF16)
|
4176
|
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
|
4177
|
+
#endif
|
4178
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 96, 96, 4>;
|
4179
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 96, 96, 4>;
|
4180
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 96, 96, 4>;
|
4181
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
|
4182
|
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
|
4183
|
+
|
4184
|
+
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>;
|
4185
|
+
#if defined(GGML_METAL_USE_BF16)
|
4186
|
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 4>;
|
4187
|
+
#endif
|
4188
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 4>;
|
4189
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 4>;
|
4190
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 4>;
|
4191
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 4>;
|
4192
|
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 4>;
|
3905
4193
|
|
3906
|
-
|
4194
|
+
template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 4>;
|
4195
|
+
#if defined(GGML_METAL_USE_BF16)
|
4196
|
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 4>;
|
4197
|
+
#endif
|
4198
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 4>;
|
4199
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 4>;
|
4200
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 4>;
|
4201
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 4>;
|
4202
|
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 4>;
|
4203
|
+
|
4204
|
+
template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 4>;
|
4205
|
+
#if defined(GGML_METAL_USE_BF16)
|
4206
|
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 4>;
|
4207
|
+
#endif
|
4208
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 4>;
|
4209
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 4>;
|
4210
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 4>;
|
4211
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 4>;
|
4212
|
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 4>;
|
3907
4213
|
|
3908
|
-
template [[host_name("
|
4214
|
+
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 4>;
|
3909
4215
|
#if defined(GGML_METAL_USE_BF16)
|
3910
|
-
template [[host_name("
|
4216
|
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 4>;
|
3911
4217
|
#endif
|
3912
|
-
template [[host_name("
|
3913
|
-
template [[host_name("
|
3914
|
-
template [[host_name("
|
3915
|
-
template [[host_name("
|
3916
|
-
template [[host_name("
|
4218
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 4>;
|
4219
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 4>;
|
4220
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 4>;
|
4221
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 4>;
|
4222
|
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 4>;
|
3917
4223
|
|
3918
|
-
template [[host_name("
|
4224
|
+
template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
|
3919
4225
|
#if defined(GGML_METAL_USE_BF16)
|
3920
|
-
template [[host_name("
|
4226
|
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
|
3921
4227
|
#endif
|
3922
|
-
template [[host_name("
|
3923
|
-
template [[host_name("
|
3924
|
-
template [[host_name("
|
3925
|
-
template [[host_name("
|
3926
|
-
template [[host_name("
|
4228
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
|
4229
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
|
4230
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
|
4231
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
|
4232
|
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
|
3927
4233
|
|
3928
4234
|
#undef FA_TYPES
|
3929
4235
|
|
@@ -4298,7 +4604,7 @@ kernel void kernel_cpy_f32_iq4_nl(
|
|
4298
4604
|
float amax = 0.0f; // absolute max
|
4299
4605
|
float max = 0.0f;
|
4300
4606
|
|
4301
|
-
for (int j = 0; j <
|
4607
|
+
for (int j = 0; j < QK4_NL; j++) {
|
4302
4608
|
const float v = src[j];
|
4303
4609
|
if (amax < fabs(v)) {
|
4304
4610
|
amax = fabs(v);
|
@@ -4332,6 +4638,49 @@ kernel void kernel_cpy_f32_iq4_nl(
|
|
4332
4638
|
}
|
4333
4639
|
}
|
4334
4640
|
|
4641
|
+
template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
4642
|
+
kernel void kernel_cpy_q_f32(
|
4643
|
+
constant ggml_metal_kargs_cpy & args,
|
4644
|
+
device const char * src0,
|
4645
|
+
device char * dst,
|
4646
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4647
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
4648
|
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
4649
|
+
const int i03 = tgpig[2];
|
4650
|
+
const int i02 = tgpig[1];
|
4651
|
+
const int i01 = tgpig[0];
|
4652
|
+
|
4653
|
+
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
4654
|
+
|
4655
|
+
const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
|
4656
|
+
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
|
4657
|
+
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
|
4658
|
+
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
|
4659
|
+
|
4660
|
+
device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
4661
|
+
device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
4662
|
+
|
4663
|
+
for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) {
|
4664
|
+
T4x4 temp;
|
4665
|
+
dequantize_func(src_data + i00/nl, i00%nl, temp);
|
4666
|
+
dst_data[i00] = temp;
|
4667
|
+
}
|
4668
|
+
}
|
4669
|
+
|
4670
|
+
typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
|
4671
|
+
|
4672
|
+
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
|
4673
|
+
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
|
4674
|
+
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
|
4675
|
+
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
|
4676
|
+
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
|
4677
|
+
|
4678
|
+
template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
|
4679
|
+
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
|
4680
|
+
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
|
4681
|
+
template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
|
4682
|
+
template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
|
4683
|
+
|
4335
4684
|
kernel void kernel_concat(
|
4336
4685
|
constant ggml_metal_kargs_concat & args,
|
4337
4686
|
device const char * src0,
|
@@ -4363,7 +4712,7 @@ kernel void kernel_concat(
|
|
4363
4712
|
}
|
4364
4713
|
}
|
4365
4714
|
|
4366
|
-
template<typename args_t>
|
4715
|
+
template<int nr0, int nsg, int nw, typename args_t>
|
4367
4716
|
void kernel_mul_mv_q2_K_f32_impl(
|
4368
4717
|
args_t args,
|
4369
4718
|
device const char * src0,
|
@@ -4379,7 +4728,7 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
4379
4728
|
const int r1 = tgpig.y;
|
4380
4729
|
const int im = tgpig.z;
|
4381
4730
|
|
4382
|
-
const int first_row = (r0 *
|
4731
|
+
const int first_row = (r0 * nsg + sgitg) * nr0;
|
4383
4732
|
|
4384
4733
|
const uint i12 = im%args.ne12;
|
4385
4734
|
const uint i13 = im/args.ne12;
|
@@ -4391,20 +4740,19 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
4391
4740
|
device const float * y = (device const float *) (src1 + offset1);
|
4392
4741
|
|
4393
4742
|
float yl[32];
|
4394
|
-
float sumf[
|
4743
|
+
float sumf[nr0]={0.f};
|
4395
4744
|
|
4396
|
-
const
|
4397
|
-
const
|
4398
|
-
const
|
4399
|
-
const
|
4400
|
-
const
|
4745
|
+
const short ix = tiisg/8; // 0...3
|
4746
|
+
const short it = tiisg%8; // 0...7
|
4747
|
+
const short iq = it/4; // 0 or 1
|
4748
|
+
const short ir = it%4; // 0...3
|
4749
|
+
const short is = (8*ir)/16;// 0 or 1
|
4401
4750
|
|
4402
4751
|
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
|
4403
4752
|
|
4404
4753
|
for (int ib = ix; ib < nb; ib += 4) {
|
4405
|
-
|
4406
4754
|
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
4407
|
-
for (
|
4755
|
+
for (short i = 0; i < 8; ++i) {
|
4408
4756
|
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
|
4409
4757
|
yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
|
4410
4758
|
yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
|
@@ -4415,8 +4763,7 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
4415
4763
|
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
|
4416
4764
|
device const half * dh = &x[ib].d;
|
4417
4765
|
|
4418
|
-
for (
|
4419
|
-
|
4766
|
+
for (short row = 0; row < nr0; row++) {
|
4420
4767
|
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
4421
4768
|
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
4422
4769
|
for (int i = 0; i < 8; i += 2) {
|
@@ -4447,10 +4794,10 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
4447
4794
|
|
4448
4795
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
4449
4796
|
|
4450
|
-
for (int row = 0; row <
|
4451
|
-
|
4797
|
+
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
4798
|
+
float sum_all = simd_sum(sumf[row]);
|
4452
4799
|
if (tiisg == 0) {
|
4453
|
-
dst_f32[first_row + row] =
|
4800
|
+
dst_f32[first_row + row] = sum_all;
|
4454
4801
|
}
|
4455
4802
|
}
|
4456
4803
|
}
|
@@ -4465,10 +4812,10 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
4465
4812
|
ushort tiisg[[thread_index_in_simdgroup]],
|
4466
4813
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
4467
4814
|
|
4468
|
-
kernel_mul_mv_q2_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
4815
|
+
kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
4469
4816
|
}
|
4470
4817
|
|
4471
|
-
template<typename args_t>
|
4818
|
+
template<int nr0, int nsg, int nw, typename args_t>
|
4472
4819
|
void kernel_mul_mv_q3_K_f32_impl(
|
4473
4820
|
args_t args,
|
4474
4821
|
device const char * src0,
|
@@ -4485,7 +4832,7 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
4485
4832
|
const int r1 = tgpig.y;
|
4486
4833
|
const int im = tgpig.z;
|
4487
4834
|
|
4488
|
-
const int first_row = (r0 *
|
4835
|
+
const int first_row = (r0 * nsg + sgitg) * nr0;
|
4489
4836
|
|
4490
4837
|
const uint i12 = im%args.ne12;
|
4491
4838
|
const uint i13 = im/args.ne12;
|
@@ -4501,13 +4848,12 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
4501
4848
|
//const uint16_t kmask1 = 0x3030;
|
4502
4849
|
//const uint16_t kmask2 = 0x0f0f;
|
4503
4850
|
|
4504
|
-
const
|
4505
|
-
const
|
4506
|
-
const
|
4507
|
-
const
|
4508
|
-
const
|
4509
|
-
const
|
4510
|
-
const int l0 = n*ir;
|
4851
|
+
const short tid = tiisg/4;
|
4852
|
+
const short ix = tiisg%4;
|
4853
|
+
const short ip = tid/4; // 0 or 1
|
4854
|
+
const short il = 2*((tid%4)/2); // 0 or 2
|
4855
|
+
const short ir = tid%2;
|
4856
|
+
const short l0 = 8*ir;
|
4511
4857
|
|
4512
4858
|
// One would think that the Metal compiler would figure out that ip and il can only have
|
4513
4859
|
// 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
|
@@ -4532,8 +4878,8 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
4532
4878
|
const uint16_t s_shift1 = 4*ip;
|
4533
4879
|
const uint16_t s_shift2 = s_shift1 + il;
|
4534
4880
|
|
4535
|
-
const
|
4536
|
-
const
|
4881
|
+
const short q_offset = 32*ip + l0;
|
4882
|
+
const short y_offset = 128*ip + 32*il + l0;
|
4537
4883
|
|
4538
4884
|
device const float * y1 = yy + ix*QK_K + y_offset;
|
4539
4885
|
|
@@ -4541,10 +4887,11 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
4541
4887
|
thread uint16_t * scales16 = (thread uint16_t *)&scales32;
|
4542
4888
|
thread const int8_t * scales = (thread const int8_t *)&scales32;
|
4543
4889
|
|
4544
|
-
float sumf1[
|
4545
|
-
float sumf2[
|
4890
|
+
float sumf1[nr0] = {0.f};
|
4891
|
+
float sumf2[nr0] = {0.f};
|
4892
|
+
|
4546
4893
|
for (int i = ix; i < nb; i += 4) {
|
4547
|
-
for (
|
4894
|
+
for (short l = 0; l < 8; ++l) {
|
4548
4895
|
yl[l+ 0] = y1[l+ 0];
|
4549
4896
|
yl[l+ 8] = y1[l+16];
|
4550
4897
|
yl[l+16] = y1[l+32];
|
@@ -4556,7 +4903,7 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
4556
4903
|
device const uint16_t * a = (device const uint16_t *)(x[i].scales);
|
4557
4904
|
device const half * dh = &x[i].d;
|
4558
4905
|
|
4559
|
-
for (
|
4906
|
+
for (short row = 0; row < nr0; ++row) {
|
4560
4907
|
const float d_all = (float)dh[0];
|
4561
4908
|
|
4562
4909
|
scales16[0] = a[4];
|
@@ -4567,7 +4914,7 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
4567
4914
|
scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
|
4568
4915
|
|
4569
4916
|
float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
|
4570
|
-
for (
|
4917
|
+
for (short l = 0; l < 8; l += 2) {
|
4571
4918
|
const int32_t qs = q[l/2];
|
4572
4919
|
s1 += yl[l+0] * (qs & qm[il/2][0]);
|
4573
4920
|
s2 += yl[l+1] * (qs & qm[il/2][1]);
|
@@ -4582,7 +4929,7 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
4582
4929
|
sumf2[row] += d2 * (scales[2] - 32);
|
4583
4930
|
|
4584
4931
|
s1 = s2 = s3 = s4 = s5 = s6 = 0;
|
4585
|
-
for (
|
4932
|
+
for (short l = 0; l < 8; l += 2) {
|
4586
4933
|
const int32_t qs = q[l/2+8];
|
4587
4934
|
s1 += yl[l+8] * (qs & qm[il/2][0]);
|
4588
4935
|
s2 += yl[l+9] * (qs & qm[il/2][1]);
|
@@ -4605,7 +4952,7 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
4605
4952
|
y1 += 4 * QK_K;
|
4606
4953
|
}
|
4607
4954
|
|
4608
|
-
for (int row = 0; row <
|
4955
|
+
for (int row = 0; row < nr0; ++row) {
|
4609
4956
|
const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
|
4610
4957
|
sumf1[row] = simd_sum(sumf);
|
4611
4958
|
}
|
@@ -4613,7 +4960,7 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
4613
4960
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
4614
4961
|
|
4615
4962
|
if (tiisg == 0) {
|
4616
|
-
for (int row = 0; row <
|
4963
|
+
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
4617
4964
|
dst_f32[first_row + row] = sumf1[row];
|
4618
4965
|
}
|
4619
4966
|
}
|
@@ -4629,10 +4976,10 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
4629
4976
|
ushort tiisg[[thread_index_in_simdgroup]],
|
4630
4977
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
4631
4978
|
|
4632
|
-
kernel_mul_mv_q3_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
4979
|
+
kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
4633
4980
|
}
|
4634
4981
|
|
4635
|
-
template<typename args_t>
|
4982
|
+
template<int nr0, int nsg, int nw, typename args_t>
|
4636
4983
|
void kernel_mul_mv_q4_K_f32_impl(
|
4637
4984
|
args_t args,
|
4638
4985
|
device const char * src0,
|
@@ -4642,22 +4989,22 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
4642
4989
|
uint3 tgpig,
|
4643
4990
|
ushort tiisg,
|
4644
4991
|
ushort sgitg) {
|
4645
|
-
|
4646
4992
|
const uint16_t kmask1 = 0x3f3f;
|
4647
4993
|
const uint16_t kmask2 = 0x0f0f;
|
4648
4994
|
const uint16_t kmask3 = 0xc0c0;
|
4649
4995
|
|
4650
|
-
const
|
4651
|
-
const
|
4652
|
-
const
|
4653
|
-
const
|
4996
|
+
const short ix = tiisg/8; // 0...3
|
4997
|
+
const short it = tiisg%8; // 0...7
|
4998
|
+
const short iq = it/4; // 0 or 1
|
4999
|
+
const short ir = it%4; // 0...3
|
4654
5000
|
|
4655
5001
|
const int nb = args.ne00/QK_K;
|
5002
|
+
|
4656
5003
|
const int r0 = tgpig.x;
|
4657
5004
|
const int r1 = tgpig.y;
|
4658
5005
|
const int im = tgpig.z;
|
4659
|
-
|
4660
|
-
const int first_row = r0 *
|
5006
|
+
|
5007
|
+
const int first_row = (r0 * nsg + sgitg) * nr0;
|
4661
5008
|
|
4662
5009
|
const uint i12 = im%args.ne12;
|
4663
5010
|
const uint i13 = im/args.ne12;
|
@@ -4670,7 +5017,8 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
4670
5017
|
|
4671
5018
|
float yl[16];
|
4672
5019
|
float yh[16];
|
4673
|
-
|
5020
|
+
|
5021
|
+
float sumf[nr0]={0.f};
|
4674
5022
|
|
4675
5023
|
device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
|
4676
5024
|
|
@@ -4679,7 +5027,8 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
4679
5027
|
|
4680
5028
|
for (int ib = ix; ib < nb; ib += 4) {
|
4681
5029
|
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
4682
|
-
|
5030
|
+
|
5031
|
+
for (short i = 0; i < 8; ++i) {
|
4683
5032
|
yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
|
4684
5033
|
yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
|
4685
5034
|
yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
|
@@ -4690,7 +5039,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
4690
5039
|
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
|
4691
5040
|
device const half * dh = &x[ib].d;
|
4692
5041
|
|
4693
|
-
for (
|
5042
|
+
for (short row = 0; row < nr0; row++) {
|
4694
5043
|
sc16[0] = sc[0] & kmask1;
|
4695
5044
|
sc16[1] = sc[2] & kmask1;
|
4696
5045
|
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
|
@@ -4700,19 +5049,21 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
4700
5049
|
|
4701
5050
|
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
4702
5051
|
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
4703
|
-
|
4704
|
-
|
4705
|
-
acc1[
|
4706
|
-
acc1[
|
4707
|
-
acc1[
|
4708
|
-
|
4709
|
-
acc2[
|
4710
|
-
acc2[
|
4711
|
-
acc2[
|
5052
|
+
|
5053
|
+
for (short i = 0; i < 4; ++i) {
|
5054
|
+
acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F);
|
5055
|
+
acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00);
|
5056
|
+
acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0);
|
5057
|
+
acc1[3] += yl[2*i + 9] * (q1[i] & 0xF000);
|
5058
|
+
acc2[0] += yh[2*i + 0] * (q2[i] & 0x000F);
|
5059
|
+
acc2[1] += yh[2*i + 1] * (q2[i] & 0x0F00);
|
5060
|
+
acc2[2] += yh[2*i + 8] * (q2[i] & 0x00F0);
|
5061
|
+
acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000);
|
4712
5062
|
}
|
4713
5063
|
|
4714
5064
|
float dall = dh[0];
|
4715
5065
|
float dmin = dh[1];
|
5066
|
+
|
4716
5067
|
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
|
4717
5068
|
(acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
|
4718
5069
|
(acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
|
@@ -4729,10 +5080,10 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
4729
5080
|
|
4730
5081
|
device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
|
4731
5082
|
|
4732
|
-
for (int row = 0; row <
|
4733
|
-
|
5083
|
+
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
5084
|
+
float sum_all = simd_sum(sumf[row]);
|
4734
5085
|
if (tiisg == 0) {
|
4735
|
-
dst_f32[first_row + row] =
|
5086
|
+
dst_f32[first_row + row] = sum_all;
|
4736
5087
|
}
|
4737
5088
|
}
|
4738
5089
|
}
|
@@ -4747,10 +5098,10 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
4747
5098
|
ushort tiisg[[thread_index_in_simdgroup]],
|
4748
5099
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
4749
5100
|
|
4750
|
-
kernel_mul_mv_q4_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
5101
|
+
kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
4751
5102
|
}
|
4752
5103
|
|
4753
|
-
template<typename args_t>
|
5104
|
+
template<int nr0, int nsg, int nw, typename args_t>
|
4754
5105
|
void kernel_mul_mv_q5_K_f32_impl(
|
4755
5106
|
args_t args,
|
4756
5107
|
device const char * src0,
|
@@ -4767,7 +5118,7 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
4767
5118
|
const int r1 = tgpig.y;
|
4768
5119
|
const int im = tgpig.z;
|
4769
5120
|
|
4770
|
-
const int first_row = (r0 *
|
5121
|
+
const int first_row = (r0 * nsg + sgitg) * nr0;
|
4771
5122
|
|
4772
5123
|
const uint i12 = im%args.ne12;
|
4773
5124
|
const uint i13 = im/args.ne12;
|
@@ -4778,7 +5129,7 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
4778
5129
|
device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
|
4779
5130
|
device const float * yy = (device const float *) (src1 + offset1);
|
4780
5131
|
|
4781
|
-
float sumf[
|
5132
|
+
float sumf[nr0]={0.f};
|
4782
5133
|
|
4783
5134
|
float yl[16], yh[16];
|
4784
5135
|
|
@@ -4786,15 +5137,14 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
4786
5137
|
const uint16_t kmask2 = 0x0f0f;
|
4787
5138
|
const uint16_t kmask3 = 0xc0c0;
|
4788
5139
|
|
4789
|
-
const
|
4790
|
-
const
|
4791
|
-
const
|
4792
|
-
const
|
4793
|
-
const int n = 8;
|
5140
|
+
const short tid = tiisg/4;
|
5141
|
+
const short ix = tiisg%4;
|
5142
|
+
const short iq = tid/4;
|
5143
|
+
const short ir = tid%4;
|
4794
5144
|
|
4795
|
-
const
|
4796
|
-
const
|
4797
|
-
const
|
5145
|
+
const short l0 = 8*ir;
|
5146
|
+
const short q_offset = 32*iq + l0;
|
5147
|
+
const short y_offset = 64*iq + l0;
|
4798
5148
|
|
4799
5149
|
const uint8_t hm1 = 1u << (2*iq);
|
4800
5150
|
const uint8_t hm2 = hm1 << 1;
|
@@ -4814,14 +5164,14 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
4814
5164
|
|
4815
5165
|
device const float * y2 = y1 + 128;
|
4816
5166
|
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
4817
|
-
for (
|
5167
|
+
for (short l = 0; l < 8; ++l) {
|
4818
5168
|
yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
|
4819
5169
|
yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
|
4820
5170
|
yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
|
4821
5171
|
yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
|
4822
5172
|
}
|
4823
5173
|
|
4824
|
-
for (
|
5174
|
+
for (short row = 0; row < nr0; ++row) {
|
4825
5175
|
device const uint8_t * q2 = q1 + 64;
|
4826
5176
|
|
4827
5177
|
sc16[0] = a[0] & kmask1;
|
@@ -4831,7 +5181,7 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
4831
5181
|
|
4832
5182
|
float4 acc1 = {0.f};
|
4833
5183
|
float4 acc2 = {0.f};
|
4834
|
-
for (
|
5184
|
+
for (short l = 0; l < 8; ++l) {
|
4835
5185
|
uint8_t h = qh[l];
|
4836
5186
|
acc1[0] += yl[l+0] * (q1[l] & 0x0F);
|
4837
5187
|
acc1[1] += yl[l+8] * (q1[l] & 0xF0);
|
@@ -4861,7 +5211,7 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
4861
5211
|
|
4862
5212
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
4863
5213
|
|
4864
|
-
for (int row = 0; row <
|
5214
|
+
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
4865
5215
|
const float tot = simd_sum(sumf[row]);
|
4866
5216
|
if (tiisg == 0) {
|
4867
5217
|
dst_f32[first_row + row] = tot;
|
@@ -4879,10 +5229,10 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
4879
5229
|
ushort tiisg[[thread_index_in_simdgroup]],
|
4880
5230
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
4881
5231
|
|
4882
|
-
kernel_mul_mv_q5_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
5232
|
+
kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, N_SG_Q5_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
4883
5233
|
}
|
4884
5234
|
|
4885
|
-
template
|
5235
|
+
template<int nr0, int nsg, int nw, typename args_t>
|
4886
5236
|
void kernel_mul_mv_q6_K_f32_impl(
|
4887
5237
|
args_t args,
|
4888
5238
|
device const char * src0,
|
@@ -4904,58 +5254,77 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|
4904
5254
|
const int r1 = tgpig.y;
|
4905
5255
|
const int im = tgpig.z;
|
4906
5256
|
|
4907
|
-
const int
|
5257
|
+
const int first_row = (r0 * nsg + sgitg) * nr0;
|
4908
5258
|
|
4909
5259
|
const uint i12 = im%args.ne12;
|
4910
5260
|
const uint i13 = im/args.ne12;
|
4911
5261
|
|
4912
|
-
const uint64_t offset0 =
|
4913
|
-
const uint64_t offset1 =
|
5262
|
+
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
5263
|
+
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
4914
5264
|
|
4915
5265
|
device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
|
4916
5266
|
device const float * yy = (device const float *) (src1 + offset1);
|
4917
5267
|
|
4918
|
-
float sumf = 0;
|
5268
|
+
float sumf[nr0] = { 0.f };
|
4919
5269
|
|
4920
|
-
|
4921
|
-
|
4922
|
-
const
|
4923
|
-
const
|
4924
|
-
const
|
4925
|
-
const
|
4926
|
-
const
|
5270
|
+
float yl[16];
|
5271
|
+
|
5272
|
+
const short tid = tiisg/2;
|
5273
|
+
const short ix = tiisg%2;
|
5274
|
+
const short ip = tid/8; // 0 or 1
|
5275
|
+
const short il = tid%8;
|
5276
|
+
const short l0 = 4*il;
|
5277
|
+
const short is = 8*ip + l0/16;
|
4927
5278
|
|
4928
|
-
const
|
4929
|
-
const
|
4930
|
-
const
|
5279
|
+
const short y_offset = 128*ip + l0;
|
5280
|
+
const short q_offset_l = 64*ip + l0;
|
5281
|
+
const short q_offset_h = 32*ip + l0;
|
4931
5282
|
|
4932
5283
|
for (int i = ix; i < nb; i += 2) {
|
4933
5284
|
device const uint8_t * q1 = x[i].ql + q_offset_l;
|
4934
5285
|
device const uint8_t * q2 = q1 + 32;
|
4935
5286
|
device const uint8_t * qh = x[i].qh + q_offset_h;
|
4936
5287
|
device const int8_t * sc = x[i].scales + is;
|
5288
|
+
device const half * dh = &x[i].d;
|
4937
5289
|
|
4938
5290
|
device const float * y = yy + i * QK_K + y_offset;
|
4939
5291
|
|
4940
|
-
|
4941
|
-
|
4942
|
-
|
4943
|
-
|
4944
|
-
|
4945
|
-
sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
4946
|
-
sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
|
4947
|
-
sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
5292
|
+
for (short l = 0; l < 4; ++l) {
|
5293
|
+
yl[4*l + 0] = y[l + 0];
|
5294
|
+
yl[4*l + 1] = y[l + 32];
|
5295
|
+
yl[4*l + 2] = y[l + 64];
|
5296
|
+
yl[4*l + 3] = y[l + 96];
|
4948
5297
|
}
|
4949
5298
|
|
4950
|
-
|
5299
|
+
for (short row = 0; row < nr0; ++row) {
|
5300
|
+
const float dall = dh[0];
|
4951
5301
|
|
5302
|
+
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
5303
|
+
|
5304
|
+
for (short l = 0; l < 4; ++l) {
|
5305
|
+
sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
5306
|
+
sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
5307
|
+
sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
|
5308
|
+
sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
5309
|
+
}
|
5310
|
+
|
5311
|
+
sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
|
5312
|
+
|
5313
|
+
q1 += args.nb01;
|
5314
|
+
q2 += args.nb01;
|
5315
|
+
qh += args.nb01;
|
5316
|
+
sc += args.nb01;
|
5317
|
+
dh += args.nb01/2;
|
5318
|
+
}
|
4952
5319
|
}
|
4953
5320
|
|
4954
5321
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
4955
5322
|
|
4956
|
-
|
4957
|
-
|
4958
|
-
|
5323
|
+
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
5324
|
+
float sum_all = simd_sum(sumf[row]);
|
5325
|
+
if (tiisg == 0) {
|
5326
|
+
dst_f32[first_row + row] = sum_all;
|
5327
|
+
}
|
4959
5328
|
}
|
4960
5329
|
}
|
4961
5330
|
|
@@ -4969,12 +5338,12 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
4969
5338
|
ushort tiisg[[thread_index_in_simdgroup]],
|
4970
5339
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
4971
5340
|
|
4972
|
-
kernel_mul_mv_q6_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
5341
|
+
kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, N_SG_Q6_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
4973
5342
|
}
|
4974
5343
|
|
4975
5344
|
// ======================= "True" 2-bit
|
4976
5345
|
|
4977
|
-
template<typename args_t>
|
5346
|
+
template<int nr0, int nsg, int nw, typename args_t>
|
4978
5347
|
void kernel_mul_mv_iq2_xxs_f32_impl(
|
4979
5348
|
args_t args,
|
4980
5349
|
device const char * src0,
|
@@ -4990,7 +5359,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
4990
5359
|
const int r1 = tgpig.y;
|
4991
5360
|
const int im = tgpig.z;
|
4992
5361
|
|
4993
|
-
const int first_row = (r0 *
|
5362
|
+
const int first_row = (r0 * nsg + sgitg) * nr0;
|
4994
5363
|
|
4995
5364
|
const uint i12 = im%args.ne12;
|
4996
5365
|
const uint i13 = im/args.ne12;
|
@@ -5002,7 +5371,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
5002
5371
|
device const float * y = (device const float *) (src1 + offset1);
|
5003
5372
|
|
5004
5373
|
float yl[32];
|
5005
|
-
float sumf[
|
5374
|
+
float sumf[nr0]={0.f};
|
5006
5375
|
|
5007
5376
|
const int nb32 = nb * (QK_K / 32);
|
5008
5377
|
|
@@ -5023,8 +5392,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
5023
5392
|
device const float * y4 = y + 32 * ix;
|
5024
5393
|
|
5025
5394
|
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
5026
|
-
|
5027
|
-
for (int i = 0; i < 32; ++i) {
|
5395
|
+
for (short i = 0; i < 32; ++i) {
|
5028
5396
|
yl[i] = y4[i];
|
5029
5397
|
}
|
5030
5398
|
|
@@ -5035,18 +5403,17 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
5035
5403
|
device const uint16_t * q2 = xr->qs + 4 * ib;
|
5036
5404
|
device const half * dh = &xr->d;
|
5037
5405
|
|
5038
|
-
for (
|
5039
|
-
|
5406
|
+
for (short row = 0; row < nr0; row++) {
|
5040
5407
|
const float db = dh[0];
|
5041
5408
|
device const uint8_t * aux8 = (device const uint8_t *)q2;
|
5042
5409
|
const uint32_t aux32 = q2[2] | (q2[3] << 16);
|
5043
5410
|
const float d = db * (0.5f + (aux32 >> 28));
|
5044
5411
|
|
5045
5412
|
float sum = 0;
|
5046
|
-
for (
|
5413
|
+
for (short l = 0; l < 4; ++l) {
|
5047
5414
|
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
|
5048
5415
|
const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
|
5049
|
-
for (
|
5416
|
+
for (short j = 0; j < 8; ++j) {
|
5050
5417
|
sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
5051
5418
|
}
|
5052
5419
|
}
|
@@ -5061,10 +5428,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
5061
5428
|
|
5062
5429
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
5063
5430
|
|
5064
|
-
for (int row = 0; row <
|
5065
|
-
|
5431
|
+
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
5432
|
+
float sum_all = simd_sum(sumf[row]);
|
5066
5433
|
if (tiisg == 0) {
|
5067
|
-
dst_f32[first_row + row] =
|
5434
|
+
dst_f32[first_row + row] = sum_all * 0.25f;
|
5068
5435
|
}
|
5069
5436
|
}
|
5070
5437
|
}
|
@@ -5079,10 +5446,10 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
|
|
5079
5446
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5080
5447
|
ushort tiisg[[thread_index_in_simdgroup]],
|
5081
5448
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
5082
|
-
kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
5449
|
+
kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
5083
5450
|
}
|
5084
5451
|
|
5085
|
-
template<typename args_t>
|
5452
|
+
template<int nr0, int nsg, int nw, typename args_t>
|
5086
5453
|
void kernel_mul_mv_iq2_xs_f32_impl(
|
5087
5454
|
args_t args,
|
5088
5455
|
device const char * src0,
|
@@ -5098,7 +5465,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
5098
5465
|
const int r1 = tgpig.y;
|
5099
5466
|
const int im = tgpig.z;
|
5100
5467
|
|
5101
|
-
const int first_row = (r0 *
|
5468
|
+
const int first_row = (r0 * nsg + sgitg) * nr0;
|
5102
5469
|
|
5103
5470
|
const uint i12 = im%args.ne12;
|
5104
5471
|
const uint i13 = im/args.ne12;
|
@@ -5110,7 +5477,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
5110
5477
|
device const float * y = (device const float *) (src1 + offset1);
|
5111
5478
|
|
5112
5479
|
float yl[32];
|
5113
|
-
float sumf[
|
5480
|
+
float sumf[nr0]={0.f};
|
5114
5481
|
|
5115
5482
|
const int nb32 = nb * (QK_K / 32);
|
5116
5483
|
|
@@ -5131,8 +5498,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
5131
5498
|
device const float * y4 = y + 32 * ix;
|
5132
5499
|
|
5133
5500
|
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
5134
|
-
|
5135
|
-
for (int i = 0; i < 32; ++i) {
|
5501
|
+
for (short i = 0; i < 32; ++i) {
|
5136
5502
|
yl[i] = y4[i];
|
5137
5503
|
}
|
5138
5504
|
|
@@ -5144,8 +5510,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
5144
5510
|
device const uint8_t * sc = xr->scales + ib;
|
5145
5511
|
device const half * dh = &xr->d;
|
5146
5512
|
|
5147
|
-
for (
|
5148
|
-
|
5513
|
+
for (short row = 0; row < nr0; row++) {
|
5149
5514
|
const float db = dh[0];
|
5150
5515
|
const uint8_t ls1 = sc[0] & 0xf;
|
5151
5516
|
const uint8_t ls2 = sc[0] >> 4;
|
@@ -5153,17 +5518,17 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
5153
5518
|
const float d2 = db * (0.5f + ls2);
|
5154
5519
|
|
5155
5520
|
float sum1 = 0, sum2 = 0;
|
5156
|
-
for (
|
5521
|
+
for (short l = 0; l < 2; ++l) {
|
5157
5522
|
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
|
5158
5523
|
const uint8_t signs = ssigns[(q2[l] >> 9)];
|
5159
|
-
for (
|
5524
|
+
for (short j = 0; j < 8; ++j) {
|
5160
5525
|
sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
5161
5526
|
}
|
5162
5527
|
}
|
5163
|
-
for (
|
5528
|
+
for (short l = 2; l < 4; ++l) {
|
5164
5529
|
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
|
5165
5530
|
const uint8_t signs = ssigns[(q2[l] >> 9)];
|
5166
|
-
for (
|
5531
|
+
for (short j = 0; j < 8; ++j) {
|
5167
5532
|
sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
5168
5533
|
}
|
5169
5534
|
}
|
@@ -5179,10 +5544,10 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
5179
5544
|
|
5180
5545
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
5181
5546
|
|
5182
|
-
for (int row = 0; row <
|
5183
|
-
|
5547
|
+
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
5548
|
+
float sum_all = simd_sum(sumf[row]);
|
5184
5549
|
if (tiisg == 0) {
|
5185
|
-
dst_f32[first_row + row] =
|
5550
|
+
dst_f32[first_row + row] = sum_all * 0.25f;
|
5186
5551
|
}
|
5187
5552
|
}
|
5188
5553
|
}
|
@@ -5198,10 +5563,10 @@ kernel void kernel_mul_mv_iq2_xs_f32(
|
|
5198
5563
|
ushort tiisg[[thread_index_in_simdgroup]],
|
5199
5564
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
5200
5565
|
|
5201
|
-
kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
5566
|
+
kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, N_SG_IQ2_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
5202
5567
|
}
|
5203
5568
|
|
5204
|
-
template
|
5569
|
+
template<int nr0, int nsg, int nw, typename args_t>
|
5205
5570
|
void kernel_mul_mv_iq3_xxs_f32_impl(
|
5206
5571
|
args_t args,
|
5207
5572
|
device const char * src0,
|
@@ -5217,7 +5582,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
5217
5582
|
const int r1 = tgpig.y;
|
5218
5583
|
const int im = tgpig.z;
|
5219
5584
|
|
5220
|
-
const int first_row = (r0 *
|
5585
|
+
const int first_row = (r0 * nsg + sgitg) * nr0;
|
5221
5586
|
|
5222
5587
|
const uint i12 = im%args.ne12;
|
5223
5588
|
const uint i13 = im/args.ne12;
|
@@ -5229,7 +5594,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
5229
5594
|
device const float * y = (device const float *) (src1 + offset1);
|
5230
5595
|
|
5231
5596
|
float yl[32];
|
5232
|
-
float sumf[
|
5597
|
+
float sumf[nr0]={0.f};
|
5233
5598
|
|
5234
5599
|
const int nb32 = nb * (QK_K / 32);
|
5235
5600
|
|
@@ -5250,7 +5615,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
5250
5615
|
device const float * y4 = y + 32 * ix;
|
5251
5616
|
|
5252
5617
|
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
5253
|
-
for (
|
5618
|
+
for (short i = 0; i < 32; ++i) {
|
5254
5619
|
yl[i] = y4[i];
|
5255
5620
|
}
|
5256
5621
|
|
@@ -5262,17 +5627,17 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
5262
5627
|
device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
|
5263
5628
|
device const half * dh = &xr->d;
|
5264
5629
|
|
5265
|
-
for (
|
5630
|
+
for (short row = 0; row < nr0; row++) {
|
5266
5631
|
const float db = dh[0];
|
5267
5632
|
const uint32_t aux32 = gas[0] | (gas[1] << 16);
|
5268
5633
|
const float d = db * (0.5f + (aux32 >> 28));
|
5269
5634
|
|
5270
5635
|
float2 sum = {0};
|
5271
|
-
for (
|
5636
|
+
for (short l = 0; l < 4; ++l) {
|
5272
5637
|
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
|
5273
5638
|
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
|
5274
5639
|
const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
|
5275
|
-
for (
|
5640
|
+
for (short j = 0; j < 4; ++j) {
|
5276
5641
|
sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
5277
5642
|
sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
|
5278
5643
|
}
|
@@ -5289,10 +5654,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
5289
5654
|
|
5290
5655
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
5291
5656
|
|
5292
|
-
for (int row = 0; row <
|
5293
|
-
|
5657
|
+
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
5658
|
+
float sum_all = simd_sum(sumf[row]);
|
5294
5659
|
if (tiisg == 0) {
|
5295
|
-
dst_f32[first_row + row] =
|
5660
|
+
dst_f32[first_row + row] = sum_all * 0.5f;
|
5296
5661
|
}
|
5297
5662
|
}
|
5298
5663
|
}
|
@@ -5308,10 +5673,10 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
|
|
5308
5673
|
ushort tiisg[[thread_index_in_simdgroup]],
|
5309
5674
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
5310
5675
|
|
5311
|
-
kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
5676
|
+
kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SG_IQ3_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
5312
5677
|
}
|
5313
5678
|
|
5314
|
-
template<typename args_t>
|
5679
|
+
template<int nr0, int nsg, int nw, typename args_t>
|
5315
5680
|
void kernel_mul_mv_iq3_s_f32_impl(
|
5316
5681
|
args_t args,
|
5317
5682
|
device const char * src0,
|
@@ -5327,7 +5692,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
5327
5692
|
const int r1 = tgpig.y;
|
5328
5693
|
const int im = tgpig.z;
|
5329
5694
|
|
5330
|
-
const int first_row = (r0 *
|
5695
|
+
const int first_row = (r0 * nsg + sgitg) * nr0;
|
5331
5696
|
|
5332
5697
|
const uint i12 = im%args.ne12;
|
5333
5698
|
const uint i13 = im/args.ne12;
|
@@ -5339,7 +5704,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
5339
5704
|
device const float * y = (device const float *) (src1 + offset1);
|
5340
5705
|
|
5341
5706
|
float yl[32];
|
5342
|
-
float sumf[
|
5707
|
+
float sumf[nr0]={0.f};
|
5343
5708
|
|
5344
5709
|
const int nb32 = nb * (QK_K / 32);
|
5345
5710
|
|
@@ -5356,8 +5721,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
5356
5721
|
device const float * y4 = y + 32 * ix;
|
5357
5722
|
|
5358
5723
|
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
5359
|
-
|
5360
|
-
for (int i = 0; i < 32; ++i) {
|
5724
|
+
for (short i = 0; i < 32; ++i) {
|
5361
5725
|
yl[i] = y4[i];
|
5362
5726
|
}
|
5363
5727
|
|
@@ -5371,18 +5735,17 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
5371
5735
|
device const uint8_t * signs = xr->signs + 4 * ib;
|
5372
5736
|
device const half * dh = &xr->d;
|
5373
5737
|
|
5374
|
-
for (
|
5375
|
-
|
5738
|
+
for (short row = 0; row < nr0; row++) {
|
5376
5739
|
const float db = dh[0];
|
5377
5740
|
const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
|
5378
5741
|
|
5379
5742
|
float2 sum = {0};
|
5380
|
-
for (
|
5743
|
+
for (short l = 0; l < 4; ++l) {
|
5381
5744
|
const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
|
5382
5745
|
const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
|
5383
5746
|
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
|
5384
5747
|
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
|
5385
|
-
for (
|
5748
|
+
for (short j = 0; j < 4; ++j) {
|
5386
5749
|
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
|
5387
5750
|
sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
|
5388
5751
|
}
|
@@ -5401,10 +5764,10 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
5401
5764
|
|
5402
5765
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
5403
5766
|
|
5404
|
-
for (int row = 0; row <
|
5405
|
-
|
5767
|
+
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
5768
|
+
float sum_all = simd_sum(sumf[row]);
|
5406
5769
|
if (tiisg == 0) {
|
5407
|
-
dst_f32[first_row + row] =
|
5770
|
+
dst_f32[first_row + row] = sum_all;
|
5408
5771
|
}
|
5409
5772
|
}
|
5410
5773
|
}
|
@@ -5420,10 +5783,10 @@ kernel void kernel_mul_mv_iq3_s_f32(
|
|
5420
5783
|
ushort tiisg[[thread_index_in_simdgroup]],
|
5421
5784
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
5422
5785
|
|
5423
|
-
kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
5786
|
+
kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, N_SG_IQ3_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
5424
5787
|
}
|
5425
5788
|
|
5426
|
-
template
|
5789
|
+
template<int nr0, int nsg, int nw, typename args_t>
|
5427
5790
|
void kernel_mul_mv_iq2_s_f32_impl(
|
5428
5791
|
args_t args,
|
5429
5792
|
device const char * src0,
|
@@ -5439,7 +5802,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|
5439
5802
|
const int r1 = tgpig.y;
|
5440
5803
|
const int im = tgpig.z;
|
5441
5804
|
|
5442
|
-
const int first_row = (r0 *
|
5805
|
+
const int first_row = (r0 * nsg + sgitg) * nr0;
|
5443
5806
|
|
5444
5807
|
const uint i12 = im%args.ne12;
|
5445
5808
|
const uint i13 = im/args.ne12;
|
@@ -5451,7 +5814,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|
5451
5814
|
device const float * y = (device const float *) (src1 + offset1);
|
5452
5815
|
|
5453
5816
|
float yl[32];
|
5454
|
-
float sumf[
|
5817
|
+
float sumf[nr0]={0.f};
|
5455
5818
|
|
5456
5819
|
const int nb32 = nb * (QK_K / 32);
|
5457
5820
|
|
@@ -5463,13 +5826,12 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|
5463
5826
|
// threadgroup_barrier(mem_flags::mem_threadgroup);
|
5464
5827
|
//}
|
5465
5828
|
|
5466
|
-
const
|
5829
|
+
const short ix = tiisg;
|
5467
5830
|
|
5468
5831
|
device const float * y4 = y + 32 * ix;
|
5469
5832
|
|
5470
5833
|
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
5471
|
-
|
5472
|
-
for (int i = 0; i < 32; ++i) {
|
5834
|
+
for (short i = 0; i < 32; ++i) {
|
5473
5835
|
yl[i] = y4[i];
|
5474
5836
|
}
|
5475
5837
|
|
@@ -5483,19 +5845,18 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|
5483
5845
|
device const uint8_t * signs = qs + QK_K/8;
|
5484
5846
|
device const half * dh = &xr->d;
|
5485
5847
|
|
5486
|
-
for (
|
5487
|
-
|
5848
|
+
for (short row = 0; row < nr0; row++) {
|
5488
5849
|
const float db = dh[0];
|
5489
5850
|
const float d1 = db * (0.5f + (sc[0] & 0xf));
|
5490
5851
|
const float d2 = db * (0.5f + (sc[0] >> 4));
|
5491
5852
|
|
5492
5853
|
float2 sum = {0};
|
5493
|
-
for (
|
5854
|
+
for (short l = 0; l < 2; ++l) {
|
5494
5855
|
//const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
|
5495
5856
|
//const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
|
5496
5857
|
constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
|
5497
5858
|
constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
|
5498
|
-
for (
|
5859
|
+
for (short j = 0; j < 8; ++j) {
|
5499
5860
|
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
|
5500
5861
|
sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
|
5501
5862
|
}
|
@@ -5514,10 +5875,10 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|
5514
5875
|
|
5515
5876
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
5516
5877
|
|
5517
|
-
for (int row = 0; row <
|
5518
|
-
|
5878
|
+
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
5879
|
+
float sum_all = simd_sum(sumf[row]);
|
5519
5880
|
if (tiisg == 0) {
|
5520
|
-
dst_f32[first_row + row] =
|
5881
|
+
dst_f32[first_row + row] = sum_all * 0.25f;
|
5521
5882
|
}
|
5522
5883
|
}
|
5523
5884
|
}
|
@@ -5533,10 +5894,10 @@ kernel void kernel_mul_mv_iq2_s_f32(
|
|
5533
5894
|
ushort tiisg[[thread_index_in_simdgroup]],
|
5534
5895
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
5535
5896
|
|
5536
|
-
kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
5897
|
+
kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, N_SG_IQ2_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
5537
5898
|
}
|
5538
5899
|
|
5539
|
-
template<typename args_t>
|
5900
|
+
template<int nr0, int nsg, int nw, typename args_t>
|
5540
5901
|
void kernel_mul_mv_iq1_s_f32_impl(
|
5541
5902
|
args_t args,
|
5542
5903
|
device const char * src0,
|
@@ -5552,7 +5913,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
5552
5913
|
const int r1 = tgpig.y;
|
5553
5914
|
const int im = tgpig.z;
|
5554
5915
|
|
5555
|
-
const int first_row = (r0 *
|
5916
|
+
const int first_row = (r0 * nsg + sgitg) * nr0;
|
5556
5917
|
|
5557
5918
|
const uint i12 = im%args.ne12;
|
5558
5919
|
const uint i13 = im/args.ne12;
|
@@ -5564,18 +5925,17 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
5564
5925
|
device const float * y = (device const float *) (src1 + offset1);
|
5565
5926
|
|
5566
5927
|
float yl[32];
|
5567
|
-
float sumf[
|
5928
|
+
float sumf[nr0]={0.f};
|
5568
5929
|
|
5569
5930
|
const int nb32 = nb * (QK_K / 32);
|
5570
5931
|
|
5571
|
-
const
|
5932
|
+
const short ix = tiisg;
|
5572
5933
|
|
5573
5934
|
device const float * y4 = y + 32 * ix;
|
5574
5935
|
|
5575
5936
|
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
5576
|
-
|
5577
5937
|
float sumy = 0;
|
5578
|
-
for (
|
5938
|
+
for (short i = 0; i < 32; ++i) {
|
5579
5939
|
yl[i] = y4[i];
|
5580
5940
|
sumy += yl[i];
|
5581
5941
|
}
|
@@ -5588,15 +5948,14 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
5588
5948
|
device const uint16_t * qh = xr->qh + ib;
|
5589
5949
|
device const half * dh = &xr->d;
|
5590
5950
|
|
5591
|
-
for (
|
5592
|
-
|
5951
|
+
for (short row = 0; row < nr0; row++) {
|
5593
5952
|
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
5594
5953
|
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
|
5595
5954
|
constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
|
5596
5955
|
constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
|
5597
5956
|
|
5598
5957
|
float sum = 0;
|
5599
|
-
for (
|
5958
|
+
for (short j = 0; j < 4; ++j) {
|
5600
5959
|
sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
|
5601
5960
|
+ yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
|
5602
5961
|
+ yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
|
@@ -5614,15 +5973,28 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
5614
5973
|
|
5615
5974
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
5616
5975
|
|
5617
|
-
for (int row = 0; row <
|
5618
|
-
|
5976
|
+
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
5977
|
+
float sum_all = simd_sum(sumf[row]);
|
5619
5978
|
if (tiisg == 0) {
|
5620
|
-
dst_f32[first_row + row] =
|
5979
|
+
dst_f32[first_row + row] = sum_all;
|
5621
5980
|
}
|
5622
5981
|
}
|
5623
5982
|
}
|
5624
5983
|
|
5625
|
-
|
5984
|
+
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
5985
|
+
kernel void kernel_mul_mv_iq1_s_f32(
|
5986
|
+
constant ggml_metal_kargs_mul_mv & args,
|
5987
|
+
device const char * src0,
|
5988
|
+
device const char * src1,
|
5989
|
+
device char * dst,
|
5990
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
5991
|
+
ushort tiisg[[thread_index_in_simdgroup]],
|
5992
|
+
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
5993
|
+
|
5994
|
+
kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, N_SG_IQ1_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
5995
|
+
}
|
5996
|
+
|
5997
|
+
template<int nr0, int nsg, int nw, typename args_t>
|
5626
5998
|
void kernel_mul_mv_iq1_m_f32_impl(
|
5627
5999
|
args_t args,
|
5628
6000
|
device const char * src0,
|
@@ -5634,11 +6006,12 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
5634
6006
|
ushort sgitg) {
|
5635
6007
|
|
5636
6008
|
const int nb = args.ne00/QK_K;
|
6009
|
+
|
5637
6010
|
const int r0 = tgpig.x;
|
5638
6011
|
const int r1 = tgpig.y;
|
5639
6012
|
const int im = tgpig.z;
|
5640
6013
|
|
5641
|
-
const int first_row = (r0 *
|
6014
|
+
const int first_row = (r0 * nsg + sgitg) * nr0;
|
5642
6015
|
|
5643
6016
|
const uint i12 = im%args.ne12;
|
5644
6017
|
const uint i13 = im/args.ne12;
|
@@ -5650,20 +6023,19 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
5650
6023
|
device const float * y = (device const float *) (src1 + offset1);
|
5651
6024
|
|
5652
6025
|
float yl[32];
|
5653
|
-
float sumf[
|
6026
|
+
float sumf[nr0]={0.f};
|
5654
6027
|
|
5655
6028
|
const int nb32 = nb * (QK_K / 32);
|
5656
6029
|
|
5657
|
-
const
|
6030
|
+
const short ix = tiisg;
|
5658
6031
|
|
5659
6032
|
device const float * y4 = y + 32 * ix;
|
5660
6033
|
|
5661
6034
|
iq1m_scale_t scale;
|
5662
6035
|
|
5663
6036
|
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
5664
|
-
|
5665
6037
|
float4 sumy = {0.f};
|
5666
|
-
for (
|
6038
|
+
for (short i = 0; i < 8; ++i) {
|
5667
6039
|
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
|
5668
6040
|
yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
|
5669
6041
|
yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
|
@@ -5678,7 +6050,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
5678
6050
|
device const uint8_t * qh = xr->qh + 2 * ib;
|
5679
6051
|
device const uint16_t * sc = (device const uint16_t *)xr->scales;
|
5680
6052
|
|
5681
|
-
for (
|
6053
|
+
for (short row = 0; row < nr0; row++) {
|
5682
6054
|
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
5683
6055
|
|
5684
6056
|
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
@@ -5687,7 +6059,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
5687
6059
|
constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
|
5688
6060
|
|
5689
6061
|
float2 sum = {0.f};
|
5690
|
-
for (
|
6062
|
+
for (short j = 0; j < 4; ++j) {
|
5691
6063
|
sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
|
5692
6064
|
+ yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
|
5693
6065
|
sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
|
@@ -5709,15 +6081,28 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
5709
6081
|
|
5710
6082
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
5711
6083
|
|
5712
|
-
for (int row = 0; row <
|
5713
|
-
|
6084
|
+
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
6085
|
+
float sum_all = simd_sum(sumf[row]);
|
5714
6086
|
if (tiisg == 0) {
|
5715
|
-
dst_f32[first_row + row] =
|
6087
|
+
dst_f32[first_row + row] = sum_all;
|
5716
6088
|
}
|
5717
6089
|
}
|
5718
6090
|
}
|
5719
6091
|
|
5720
|
-
|
6092
|
+
[[host_name("kernel_mul_mv_iq1_m_f32")]]
|
6093
|
+
kernel void kernel_mul_mv_iq1_m_f32(
|
6094
|
+
constant ggml_metal_kargs_mul_mv & args,
|
6095
|
+
device const char * src0,
|
6096
|
+
device const char * src1,
|
6097
|
+
device char * dst,
|
6098
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
6099
|
+
ushort tiisg[[thread_index_in_simdgroup]],
|
6100
|
+
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
6101
|
+
|
6102
|
+
kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, N_SG_IQ1_M, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
6103
|
+
}
|
6104
|
+
|
6105
|
+
template<int nr0, int nsg, int nw, typename args_t>
|
5721
6106
|
void kernel_mul_mv_iq4_nl_f32_impl(
|
5722
6107
|
args_t args,
|
5723
6108
|
device const char * src0,
|
@@ -5730,10 +6115,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
5730
6115
|
|
5731
6116
|
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
5732
6117
|
const int nb = args.ne00/QK4_NL;
|
6118
|
+
|
5733
6119
|
const int r0 = tgpig.x;
|
5734
6120
|
const int r1 = tgpig.y;
|
5735
6121
|
const int im = tgpig.z;
|
5736
|
-
|
6122
|
+
|
6123
|
+
const int first_row = (r0 * nsg + sgitg) * nr0;
|
5737
6124
|
|
5738
6125
|
const uint i12 = im%args.ne12;
|
5739
6126
|
const uint i13 = im/args.ne12;
|
@@ -5744,14 +6131,14 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
5744
6131
|
device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
|
5745
6132
|
device const float * y = (device const float *) (src1 + offset1);
|
5746
6133
|
|
5747
|
-
const
|
5748
|
-
const
|
6134
|
+
const short ix = tiisg/2; // 0...15
|
6135
|
+
const short it = tiisg%2; // 0 or 1
|
5749
6136
|
|
5750
6137
|
shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
|
5751
6138
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
5752
6139
|
|
5753
6140
|
float4 yl[4];
|
5754
|
-
float sumf[
|
6141
|
+
float sumf[nr0]={0.f};
|
5755
6142
|
|
5756
6143
|
device const float * yb = y + ix * QK4_NL + it * 8;
|
5757
6144
|
|
@@ -5761,12 +6148,13 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
5761
6148
|
float4 qf1, qf2;
|
5762
6149
|
|
5763
6150
|
for (int ib = ix; ib < nb; ib += 16) {
|
5764
|
-
|
5765
6151
|
device const float4 * y4 = (device const float4 *)yb;
|
5766
|
-
yl[0] = y4[0];
|
5767
|
-
|
5768
|
-
|
6152
|
+
yl[0] = y4[0];
|
6153
|
+
yl[1] = y4[4];
|
6154
|
+
yl[2] = y4[1];
|
6155
|
+
yl[3] = y4[5];
|
5769
6156
|
|
6157
|
+
for (short row = 0; row < nr0; row++) {
|
5770
6158
|
device const block_iq4_nl & xb = x[row*nb + ib];
|
5771
6159
|
device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
|
5772
6160
|
|
@@ -5791,7 +6179,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
5791
6179
|
acc1 += acc2;
|
5792
6180
|
|
5793
6181
|
sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
|
5794
|
-
|
5795
6182
|
}
|
5796
6183
|
|
5797
6184
|
yb += 16 * QK4_NL;
|
@@ -5799,15 +6186,29 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
5799
6186
|
|
5800
6187
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
5801
6188
|
|
5802
|
-
for (int row = 0; row <
|
5803
|
-
|
6189
|
+
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
6190
|
+
float sum_all = simd_sum(sumf[row]);
|
5804
6191
|
if (tiisg == 0) {
|
5805
|
-
dst_f32[first_row + row] =
|
6192
|
+
dst_f32[first_row + row] = sum_all;
|
5806
6193
|
}
|
5807
6194
|
}
|
5808
6195
|
}
|
5809
6196
|
|
5810
|
-
|
6197
|
+
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
6198
|
+
kernel void kernel_mul_mv_iq4_nl_f32(
|
6199
|
+
constant ggml_metal_kargs_mul_mv & args,
|
6200
|
+
device const char * src0,
|
6201
|
+
device const char * src1,
|
6202
|
+
device char * dst,
|
6203
|
+
threadgroup char * shmem [[threadgroup(0)]],
|
6204
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
6205
|
+
ushort tiisg[[thread_index_in_simdgroup]],
|
6206
|
+
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
6207
|
+
|
6208
|
+
kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
6209
|
+
}
|
6210
|
+
|
6211
|
+
template<int nr0, int nsg, int nw, typename args_t>
|
5811
6212
|
void kernel_mul_mv_iq4_xs_f32_impl(
|
5812
6213
|
args_t args,
|
5813
6214
|
device const char * src0,
|
@@ -5823,7 +6224,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
5823
6224
|
const int r0 = tgpig.x;
|
5824
6225
|
const int r1 = tgpig.y;
|
5825
6226
|
const int im = tgpig.z;
|
5826
|
-
const int first_row = (r0 *
|
6227
|
+
const int first_row = (r0 * nsg + sgitg) * nr0;
|
5827
6228
|
|
5828
6229
|
const uint i12 = im%args.ne12;
|
5829
6230
|
const uint i13 = im/args.ne12;
|
@@ -5834,16 +6235,16 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
5834
6235
|
device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
|
5835
6236
|
device const float * y = (device const float *) (src1 + offset1);
|
5836
6237
|
|
5837
|
-
const
|
5838
|
-
const
|
5839
|
-
const
|
5840
|
-
const
|
6238
|
+
const short ix = tiisg/16; // 0 or 1
|
6239
|
+
const short it = tiisg%16; // 0...15
|
6240
|
+
const short ib = it/2;
|
6241
|
+
const short il = it%2;
|
5841
6242
|
|
5842
6243
|
shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
|
5843
6244
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
5844
6245
|
|
5845
6246
|
float4 yl[4];
|
5846
|
-
float sumf[
|
6247
|
+
float sumf[nr0]={0.f};
|
5847
6248
|
|
5848
6249
|
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
|
5849
6250
|
|
@@ -5854,9 +6255,12 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
5854
6255
|
|
5855
6256
|
for (int ibl = ix; ibl < nb; ibl += 2) {
|
5856
6257
|
device const float4 * y4 = (device const float4 *)yb;
|
5857
|
-
yl[0] = y4[0];
|
6258
|
+
yl[0] = y4[0];
|
6259
|
+
yl[1] = y4[4];
|
6260
|
+
yl[2] = y4[1];
|
6261
|
+
yl[3] = y4[5];
|
5858
6262
|
|
5859
|
-
for (
|
6263
|
+
for (short row = 0; row < nr0; ++row) {
|
5860
6264
|
device const block_iq4_xs & xb = x[row*nb + ibl];
|
5861
6265
|
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
|
5862
6266
|
|
@@ -5880,7 +6284,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
5880
6284
|
|
5881
6285
|
const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
|
5882
6286
|
sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
|
5883
|
-
|
5884
6287
|
}
|
5885
6288
|
|
5886
6289
|
yb += 2 * QK_K;
|
@@ -5888,54 +6291,14 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
5888
6291
|
|
5889
6292
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
5890
6293
|
|
5891
|
-
for (int row = 0; row <
|
5892
|
-
|
6294
|
+
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
6295
|
+
float sum_all = simd_sum(sumf[row]);
|
5893
6296
|
if (tiisg == 0) {
|
5894
|
-
dst_f32[first_row + row] =
|
6297
|
+
dst_f32[first_row + row] = sum_all;
|
5895
6298
|
}
|
5896
6299
|
}
|
5897
6300
|
}
|
5898
6301
|
|
5899
|
-
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
5900
|
-
kernel void kernel_mul_mv_iq1_s_f32(
|
5901
|
-
constant ggml_metal_kargs_mul_mv & args,
|
5902
|
-
device const char * src0,
|
5903
|
-
device const char * src1,
|
5904
|
-
device char * dst,
|
5905
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
5906
|
-
ushort tiisg[[thread_index_in_simdgroup]],
|
5907
|
-
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
5908
|
-
|
5909
|
-
kernel_mul_mv_iq1_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
5910
|
-
}
|
5911
|
-
|
5912
|
-
[[host_name("kernel_mul_mv_iq1_m_f32")]]
|
5913
|
-
kernel void kernel_mul_mv_iq1_m_f32(
|
5914
|
-
constant ggml_metal_kargs_mul_mv & args,
|
5915
|
-
device const char * src0,
|
5916
|
-
device const char * src1,
|
5917
|
-
device char * dst,
|
5918
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
5919
|
-
ushort tiisg[[thread_index_in_simdgroup]],
|
5920
|
-
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
5921
|
-
|
5922
|
-
kernel_mul_mv_iq1_m_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
5923
|
-
}
|
5924
|
-
|
5925
|
-
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
5926
|
-
kernel void kernel_mul_mv_iq4_nl_f32(
|
5927
|
-
constant ggml_metal_kargs_mul_mv & args,
|
5928
|
-
device const char * src0,
|
5929
|
-
device const char * src1,
|
5930
|
-
device char * dst,
|
5931
|
-
threadgroup char * shmem [[threadgroup(0)]],
|
5932
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
5933
|
-
ushort tiisg[[thread_index_in_simdgroup]],
|
5934
|
-
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
5935
|
-
|
5936
|
-
kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
5937
|
-
}
|
5938
|
-
|
5939
6302
|
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
|
5940
6303
|
kernel void kernel_mul_mv_iq4_xs_f32(
|
5941
6304
|
constant ggml_metal_kargs_mul_mv & args,
|
@@ -5947,7 +6310,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
|
5947
6310
|
ushort tiisg[[thread_index_in_simdgroup]],
|
5948
6311
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
5949
6312
|
|
5950
|
-
kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
6313
|
+
kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
5951
6314
|
}
|
5952
6315
|
|
5953
6316
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
@@ -5955,28 +6318,21 @@ kernel void kernel_get_rows_q(
|
|
5955
6318
|
device const void * src0,
|
5956
6319
|
device const void * src1,
|
5957
6320
|
device float * dst,
|
5958
|
-
constant
|
5959
|
-
constant uint64_t & nb01,
|
5960
|
-
constant uint64_t & nb02,
|
5961
|
-
constant int64_t & ne10,
|
5962
|
-
constant uint64_t & nb10,
|
5963
|
-
constant uint64_t & nb11,
|
5964
|
-
constant uint64_t & nb1,
|
5965
|
-
constant uint64_t & nb2,
|
6321
|
+
constant ggml_metal_kargs_get_rows & args,
|
5966
6322
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5967
6323
|
uint tiitg[[thread_index_in_threadgroup]],
|
5968
6324
|
uint3 tptg [[threads_per_threadgroup]]) {
|
5969
6325
|
const int64_t i10 = tgpig.x;
|
5970
6326
|
const int64_t i11 = tgpig.y;
|
5971
6327
|
|
5972
|
-
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
|
6328
|
+
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
|
5973
6329
|
|
5974
6330
|
const int64_t i02 = i11;
|
5975
6331
|
|
5976
|
-
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
|
6332
|
+
for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) {
|
5977
6333
|
float4x4 temp;
|
5978
|
-
dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
|
5979
|
-
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
|
6334
|
+
dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp);
|
6335
|
+
*(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp;
|
5980
6336
|
}
|
5981
6337
|
}
|
5982
6338
|
|
@@ -5985,27 +6341,20 @@ kernel void kernel_get_rows_f(
|
|
5985
6341
|
device const void * src0,
|
5986
6342
|
device const void * src1,
|
5987
6343
|
device float * dst,
|
5988
|
-
constant
|
5989
|
-
constant uint64_t & nb01,
|
5990
|
-
constant uint64_t & nb02,
|
5991
|
-
constant int64_t & ne10,
|
5992
|
-
constant uint64_t & nb10,
|
5993
|
-
constant uint64_t & nb11,
|
5994
|
-
constant uint64_t & nb1,
|
5995
|
-
constant uint64_t & nb2,
|
6344
|
+
constant ggml_metal_kargs_get_rows & args,
|
5996
6345
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5997
6346
|
uint tiitg[[thread_index_in_threadgroup]],
|
5998
6347
|
uint3 tptg [[threads_per_threadgroup]]) {
|
5999
6348
|
const int64_t i10 = tgpig.x;
|
6000
6349
|
const int64_t i11 = tgpig.y;
|
6001
6350
|
|
6002
|
-
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
|
6351
|
+
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
|
6003
6352
|
|
6004
6353
|
const int64_t i02 = i11;
|
6005
6354
|
|
6006
|
-
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
6007
|
-
(( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
6008
|
-
((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
|
6355
|
+
for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
|
6356
|
+
(( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
|
6357
|
+
((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
|
6009
6358
|
}
|
6010
6359
|
}
|
6011
6360
|
|
@@ -6013,27 +6362,20 @@ kernel void kernel_get_rows_i32(
|
|
6013
6362
|
device const void * src0,
|
6014
6363
|
device const void * src1,
|
6015
6364
|
device int32_t * dst,
|
6016
|
-
constant
|
6017
|
-
constant uint64_t & nb01,
|
6018
|
-
constant uint64_t & nb02,
|
6019
|
-
constant int64_t & ne10,
|
6020
|
-
constant uint64_t & nb10,
|
6021
|
-
constant uint64_t & nb11,
|
6022
|
-
constant uint64_t & nb1,
|
6023
|
-
constant uint64_t & nb2,
|
6365
|
+
constant ggml_metal_kargs_get_rows & args,
|
6024
6366
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6025
6367
|
uint tiitg[[thread_index_in_threadgroup]],
|
6026
6368
|
uint3 tptg [[threads_per_threadgroup]]) {
|
6027
6369
|
const int64_t i10 = tgpig.x;
|
6028
6370
|
const int64_t i11 = tgpig.y;
|
6029
6371
|
|
6030
|
-
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
|
6372
|
+
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
|
6031
6373
|
|
6032
6374
|
const int64_t i02 = i11;
|
6033
6375
|
|
6034
|
-
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
6035
|
-
(( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
6036
|
-
((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
|
6376
|
+
for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
|
6377
|
+
(( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
|
6378
|
+
((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
|
6037
6379
|
}
|
6038
6380
|
}
|
6039
6381
|
|
@@ -6192,127 +6534,219 @@ kernel void kernel_mul_mm(
|
|
6192
6534
|
}
|
6193
6535
|
}
|
6194
6536
|
|
6195
|
-
|
6196
|
-
|
6197
|
-
|
6198
|
-
|
6199
|
-
|
6200
|
-
|
6201
|
-
|
6202
|
-
|
6203
|
-
|
6204
|
-
|
6205
|
-
|
6206
|
-
|
6207
|
-
|
6208
|
-
|
6209
|
-
|
6210
|
-
|
6211
|
-
|
6212
|
-
|
6213
|
-
|
6214
|
-
|
6215
|
-
|
6537
|
+
template<typename T4>
|
6538
|
+
kernel void kernel_mul_mm_id_map0(
|
6539
|
+
constant ggml_metal_kargs_mul_mm_id_map0 & args,
|
6540
|
+
device const char * src1,
|
6541
|
+
device const char * src2,
|
6542
|
+
device char * hsrc1,
|
6543
|
+
device char * htpe,
|
6544
|
+
device char * hids,
|
6545
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
6546
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
6547
|
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
6548
|
+
const int ide = tgpig[0]; // expert id
|
6549
|
+
|
6550
|
+
int n_all = 0;
|
6551
|
+
|
6552
|
+
device int32_t * ids_i32 = (device int32_t *) (hids);
|
6553
|
+
|
6554
|
+
for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens
|
6555
|
+
device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
|
6556
|
+
|
6557
|
+
for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
|
6558
|
+
if (src2_i32[i20] != ide) {
|
6559
|
+
continue;
|
6560
|
+
}
|
6561
|
+
|
6562
|
+
device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11);
|
6563
|
+
device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
|
6564
|
+
|
6565
|
+
for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) {
|
6566
|
+
hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]);
|
6567
|
+
}
|
6568
|
+
|
6569
|
+
if (tpitg.x == 0) {
|
6570
|
+
ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
|
6571
|
+
}
|
6572
|
+
|
6573
|
+
++n_all;
|
6574
|
+
}
|
6575
|
+
}
|
6576
|
+
|
6577
|
+
if (tpitg.x == 0) {
|
6578
|
+
device int32_t * tpe_i32 = (device int32_t *) (htpe);
|
6579
|
+
tpe_i32[ide] = n_all;
|
6580
|
+
}
|
6581
|
+
}
|
6582
|
+
|
6583
|
+
typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
|
6584
|
+
|
6585
|
+
template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
|
6586
|
+
|
6587
|
+
template<typename T>
|
6588
|
+
kernel void kernel_mul_mm_id_map1(
|
6589
|
+
constant ggml_metal_kargs_mul_mm_id_map1 & args,
|
6590
|
+
device const char * hdst,
|
6591
|
+
device const char * hids,
|
6592
|
+
device char * dst,
|
6593
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
6594
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
6595
|
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
6596
|
+
const int i20 = tgpig[0]; // used expert
|
6597
|
+
const int i21 = tgpig[1]; // token
|
6598
|
+
|
6599
|
+
device const int32_t * ids_i32 = (device const int32_t *) (hids);
|
6600
|
+
device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2);
|
6601
|
+
|
6602
|
+
const int id = ids_i32[i21*args.ne20 + i20];
|
6603
|
+
|
6604
|
+
const int ide = id / args.neh1;
|
6605
|
+
const int idt = id % args.neh1;
|
6606
|
+
|
6607
|
+
device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2);
|
6608
|
+
|
6609
|
+
for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) {
|
6610
|
+
dst_f32x4[i0] = hdst_f32x4[i0];
|
6611
|
+
}
|
6612
|
+
}
|
6613
|
+
|
6614
|
+
typedef decltype(kernel_mul_mm_id_map1<float>) kernel_mul_mm_id_map1_t;
|
6615
|
+
|
6616
|
+
template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1<float>;
|
6617
|
+
|
6618
|
+
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
6619
|
+
kernel void kernel_mul_mm_id(
|
6620
|
+
constant ggml_metal_kargs_mul_mm_id & args,
|
6621
|
+
device const char * src0,
|
6622
|
+
device const char * src1,
|
6623
|
+
device const char * tpe,
|
6624
|
+
device char * dst,
|
6625
|
+
threadgroup char * shmem [[threadgroup(0)]],
|
6216
6626
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6217
6627
|
ushort tiitg[[thread_index_in_threadgroup]],
|
6218
6628
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
6219
6629
|
|
6220
|
-
threadgroup
|
6221
|
-
threadgroup
|
6630
|
+
threadgroup T * sa = (threadgroup T *)(shmem);
|
6631
|
+
threadgroup half * sb = (threadgroup half *)(shmem + 4096);
|
6222
6632
|
|
6223
6633
|
const int r0 = tgpig.y;
|
6224
6634
|
const int r1 = tgpig.x;
|
6635
|
+
const int im = tgpig.z;
|
6225
6636
|
|
6226
|
-
|
6637
|
+
device const int32_t * tpe_i32 = (device const int32_t *) (tpe);
|
6638
|
+
|
6639
|
+
const int neh1 = tpe_i32[im];
|
6640
|
+
|
6641
|
+
if (r1*BLOCK_SIZE_N >= neh1) {
|
6642
|
+
return;
|
6643
|
+
}
|
6227
6644
|
|
6228
6645
|
// if this block is of 64x32 shape or smaller
|
6229
|
-
short n_rows = (
|
6230
|
-
short n_cols = (
|
6646
|
+
const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
6647
|
+
const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
6231
6648
|
|
6232
6649
|
// a thread shouldn't load data outside of the matrix
|
6233
|
-
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
6234
|
-
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
6650
|
+
const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
6651
|
+
const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
6235
6652
|
|
6236
|
-
|
6237
|
-
|
6653
|
+
simdgroup_T8x8 ma[4];
|
6654
|
+
simdgroup_half8x8 mb[2];
|
6238
6655
|
simdgroup_float8x8 mc[8];
|
6239
|
-
|
6656
|
+
|
6657
|
+
for (short i = 0; i < 8; i++){
|
6240
6658
|
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
6241
6659
|
}
|
6660
|
+
|
6242
6661
|
short il = (tiitg % THREAD_PER_ROW);
|
6243
6662
|
|
6244
|
-
|
6663
|
+
const int i12 = im%args.neh12;
|
6664
|
+
const int i13 = im/args.neh12;
|
6245
6665
|
|
6246
|
-
|
6666
|
+
const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
6667
|
+
const short offset1 = il/nl;
|
6247
6668
|
|
6248
|
-
device const block_q * x = (device const block_q *)(src0
|
6249
|
-
|
6250
|
-
|
6251
|
-
|
6252
|
-
+
|
6669
|
+
device const block_q * x = (device const block_q *)(src0
|
6670
|
+
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
|
6671
|
+
|
6672
|
+
device const half * y = (device const half *)(src1
|
6673
|
+
+ args.nbh13*i13
|
6674
|
+
+ args.nbh12*i12
|
6675
|
+
+ args.nbh11*(r1*BLOCK_SIZE_N + thread_col)
|
6676
|
+
+ args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
6253
6677
|
|
6254
|
-
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
6678
|
+
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
|
6255
6679
|
// load data and store to threadgroup memory
|
6256
|
-
|
6680
|
+
T4x4 temp_a;
|
6257
6681
|
dequantize_func(x, il, temp_a);
|
6682
|
+
|
6258
6683
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
6259
6684
|
|
6260
|
-
|
6261
|
-
|
6262
|
-
|
6263
|
-
+ (tiitg
|
6685
|
+
#pragma unroll(16)
|
6686
|
+
for (short i = 0; i < 16; i++) {
|
6687
|
+
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
|
6688
|
+
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
|
6689
|
+
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
|
6264
6690
|
}
|
6265
6691
|
|
6266
|
-
*(threadgroup
|
6692
|
+
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y);
|
6267
6693
|
|
6268
6694
|
il = (il + 2 < nl) ? il + 2 : il % 2;
|
6269
|
-
x = (il < 2) ? x + (2+nl-1)/nl : x;
|
6695
|
+
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
6270
6696
|
y += BLOCK_SIZE_K;
|
6271
6697
|
|
6272
6698
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
6273
6699
|
|
6274
6700
|
// load matrices from threadgroup memory and conduct outer products
|
6275
|
-
threadgroup
|
6276
|
-
threadgroup
|
6701
|
+
threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
|
6702
|
+
threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
|
6277
6703
|
|
6278
|
-
#pragma unroll(
|
6279
|
-
for (
|
6704
|
+
#pragma unroll(4)
|
6705
|
+
for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
|
6280
6706
|
#pragma unroll(4)
|
6281
|
-
for (
|
6707
|
+
for (short i = 0; i < 4; i++) {
|
6282
6708
|
simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
|
6283
6709
|
}
|
6710
|
+
|
6284
6711
|
simdgroup_barrier(mem_flags::mem_none);
|
6712
|
+
|
6285
6713
|
#pragma unroll(2)
|
6286
|
-
for (
|
6714
|
+
for (short i = 0; i < 2; i++) {
|
6287
6715
|
simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
|
6288
6716
|
}
|
6289
6717
|
|
6290
|
-
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
6291
|
-
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
6292
|
-
|
6293
6718
|
#pragma unroll(8)
|
6294
|
-
for (
|
6719
|
+
for (short i = 0; i < 8; i++){
|
6295
6720
|
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
|
6296
6721
|
}
|
6722
|
+
|
6723
|
+
lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
|
6724
|
+
lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
|
6297
6725
|
}
|
6298
6726
|
}
|
6299
6727
|
|
6300
|
-
{
|
6728
|
+
if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) {
|
6729
|
+
device float * C = (device float *) dst +
|
6730
|
+
(BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
|
6731
|
+
(BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0;
|
6732
|
+
|
6733
|
+
for (short i = 0; i < 8; i++) {
|
6734
|
+
simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0);
|
6735
|
+
}
|
6736
|
+
} else {
|
6737
|
+
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
6301
6738
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
6302
6739
|
threadgroup float * temp_str = ((threadgroup float *) shmem) \
|
6303
|
-
|
6304
|
-
for (
|
6305
|
-
simdgroup_store(mc[i], temp_str + 8
|
6740
|
+
+ 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
|
6741
|
+
for (short i = 0; i < 8; i++) {
|
6742
|
+
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
|
6306
6743
|
}
|
6307
6744
|
|
6308
6745
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
6309
6746
|
|
6310
6747
|
if (sgitg == 0) {
|
6311
6748
|
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
6312
|
-
|
6313
|
-
int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1;
|
6314
|
-
|
6315
|
-
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff;
|
6749
|
+
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0;
|
6316
6750
|
device float4 * D4 = (device float4 *) D;
|
6317
6751
|
|
6318
6752
|
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
|
@@ -6332,66 +6766,6 @@ void kernel_mul_mm_id_impl(
|
|
6332
6766
|
}
|
6333
6767
|
}
|
6334
6768
|
|
6335
|
-
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
6336
|
-
kernel void kernel_mul_mm_id(
|
6337
|
-
constant ggml_metal_kargs_mul_mm_id & args,
|
6338
|
-
device const char * src0s,
|
6339
|
-
device const char * src1,
|
6340
|
-
device char * dst,
|
6341
|
-
device const char * ids,
|
6342
|
-
threadgroup char * shmem [[threadgroup(0)]],
|
6343
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
6344
|
-
ushort tiitg[[thread_index_in_threadgroup]],
|
6345
|
-
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
6346
|
-
|
6347
|
-
const int32_t i02 = tgpig.z;
|
6348
|
-
|
6349
|
-
tgpig.z = 0;
|
6350
|
-
|
6351
|
-
device const char * src0 = src0s + i02*args.nb02;
|
6352
|
-
|
6353
|
-
// row indices
|
6354
|
-
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
|
6355
|
-
|
6356
|
-
// TODO: parallelize this loop
|
6357
|
-
int32_t _ne1 = 0;
|
6358
|
-
for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
|
6359
|
-
for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
|
6360
|
-
int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
|
6361
|
-
if (id == i02) {
|
6362
|
-
if (tiitg == 0) {
|
6363
|
-
rowids[_ne1] = ushort2(ii0, ii1);
|
6364
|
-
}
|
6365
|
-
_ne1++;
|
6366
|
-
}
|
6367
|
-
}
|
6368
|
-
}
|
6369
|
-
|
6370
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
6371
|
-
|
6372
|
-
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
6373
|
-
args.ne00,
|
6374
|
-
args.ne02,
|
6375
|
-
args.nb01,
|
6376
|
-
args.nb02,
|
6377
|
-
args.ne11,
|
6378
|
-
args.ne12,
|
6379
|
-
args.nb10,
|
6380
|
-
args.nb11,
|
6381
|
-
args.nb12,
|
6382
|
-
args.ne0,
|
6383
|
-
_ne1,
|
6384
|
-
(int64_t)args.ne0*args.ne1,
|
6385
|
-
src0,
|
6386
|
-
src1,
|
6387
|
-
rowids,
|
6388
|
-
dst,
|
6389
|
-
shmem,
|
6390
|
-
tgpig,
|
6391
|
-
tiitg,
|
6392
|
-
sgitg);
|
6393
|
-
}
|
6394
|
-
|
6395
6769
|
#define QK_NL 16
|
6396
6770
|
|
6397
6771
|
//
|
@@ -6432,63 +6806,64 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get
|
|
6432
6806
|
// matrix-matrix multiplication
|
6433
6807
|
//
|
6434
6808
|
|
6435
|
-
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>)
|
6809
|
+
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_t;
|
6436
6810
|
|
6437
|
-
template [[host_name("kernel_mul_mm_f32_f32")]] kernel
|
6438
|
-
template [[host_name("kernel_mul_mm_f16_f32")]] kernel
|
6811
|
+
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
6812
|
+
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
6439
6813
|
#if defined(GGML_METAL_USE_BF16)
|
6440
|
-
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel
|
6814
|
+
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
6441
6815
|
#endif
|
6442
|
-
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel
|
6443
|
-
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel
|
6444
|
-
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel
|
6445
|
-
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel
|
6446
|
-
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel
|
6447
|
-
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel
|
6448
|
-
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel
|
6449
|
-
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel
|
6450
|
-
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel
|
6451
|
-
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel
|
6452
|
-
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel
|
6453
|
-
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel
|
6454
|
-
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel
|
6455
|
-
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel
|
6456
|
-
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel
|
6457
|
-
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel
|
6458
|
-
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel
|
6459
|
-
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel
|
6460
|
-
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel
|
6816
|
+
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
6817
|
+
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
6818
|
+
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
6819
|
+
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
6820
|
+
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
6821
|
+
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
6822
|
+
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
6823
|
+
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
|
6824
|
+
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
|
6825
|
+
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
|
6826
|
+
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
6827
|
+
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
6828
|
+
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
6829
|
+
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
|
6830
|
+
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
|
6831
|
+
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
|
6832
|
+
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
|
6833
|
+
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
6834
|
+
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6461
6835
|
|
6462
6836
|
//
|
6463
6837
|
// indirect matrix-matrix multiplication
|
6464
6838
|
//
|
6465
6839
|
|
6466
|
-
typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>)
|
6840
|
+
typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_id;
|
6467
6841
|
|
6468
|
-
template [[host_name("
|
6469
|
-
template [[host_name("
|
6842
|
+
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
6843
|
+
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
6470
6844
|
#if defined(GGML_METAL_USE_BF16)
|
6471
|
-
template [[host_name("
|
6845
|
+
template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
6472
6846
|
#endif
|
6473
|
-
template [[host_name("
|
6474
|
-
template [[host_name("
|
6475
|
-
template [[host_name("
|
6476
|
-
template [[host_name("
|
6477
|
-
template [[host_name("
|
6478
|
-
template [[host_name("
|
6479
|
-
template [[host_name("
|
6480
|
-
template [[host_name("
|
6481
|
-
template [[host_name("
|
6482
|
-
template [[host_name("
|
6483
|
-
template [[host_name("
|
6484
|
-
template [[host_name("
|
6485
|
-
template [[host_name("
|
6486
|
-
template [[host_name("
|
6487
|
-
template [[host_name("
|
6488
|
-
template [[host_name("
|
6489
|
-
template [[host_name("
|
6490
|
-
template [[host_name("
|
6491
|
-
template [[host_name("
|
6847
|
+
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
6848
|
+
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
6849
|
+
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
6850
|
+
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
6851
|
+
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
6852
|
+
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
6853
|
+
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
6854
|
+
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
|
6855
|
+
template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
|
6856
|
+
template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
|
6857
|
+
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
6858
|
+
template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
6859
|
+
template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
6860
|
+
template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
|
6861
|
+
template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
|
6862
|
+
template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
|
6863
|
+
template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
|
6864
|
+
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
6865
|
+
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6866
|
+
|
6492
6867
|
|
6493
6868
|
//
|
6494
6869
|
// matrix-vector multiplication
|
@@ -6612,121 +6987,103 @@ template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t
|
|
6612
6987
|
#if defined(GGML_METAL_USE_BF16)
|
6613
6988
|
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
|
6614
6989
|
#endif
|
6615
|
-
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl
|
6616
|
-
|
6617
|
-
template [[host_name("
|
6618
|
-
template [[host_name("
|
6619
|
-
template [[host_name("
|
6620
|
-
template [[host_name("
|
6621
|
-
|
6622
|
-
template [[host_name("
|
6623
|
-
template [[host_name("
|
6624
|
-
template [[host_name("
|
6625
|
-
template [[host_name("
|
6626
|
-
template [[host_name("
|
6627
|
-
template [[host_name("
|
6628
|
-
template [[host_name("
|
6629
|
-
template [[host_name("
|
6630
|
-
template [[host_name("
|
6631
|
-
template [[host_name("
|
6632
|
-
template [[host_name("
|
6633
|
-
template [[host_name("
|
6990
|
+
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH>>>;
|
6991
|
+
|
6992
|
+
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH>>>;
|
6993
|
+
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH>>>;
|
6994
|
+
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH>>>;
|
6995
|
+
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH>>>;
|
6996
|
+
|
6997
|
+
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH>>>;
|
6998
|
+
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH>>>;
|
6999
|
+
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH>>>;
|
7000
|
+
template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl <N_R0_Q5_K, N_SG_Q5_K, N_SIMDWIDTH>>>;
|
7001
|
+
template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl <N_R0_Q6_K, N_SG_Q6_K, N_SIMDWIDTH>>>;
|
7002
|
+
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl <N_R0_IQ1_S, N_SG_IQ1_S, N_SIMDWIDTH>>>;
|
7003
|
+
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl <N_R0_IQ1_M, N_SG_IQ1_M, N_SIMDWIDTH>>>;
|
7004
|
+
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH>>>;
|
7005
|
+
template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS, N_SG_IQ2_XS, N_SIMDWIDTH>>>;
|
7006
|
+
template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SG_IQ3_XXS, N_SIMDWIDTH>>>;
|
7007
|
+
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl <N_R0_IQ3_S, N_SG_IQ3_S, N_SIMDWIDTH>>>;
|
7008
|
+
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl <N_R0_IQ2_S, N_SG_IQ2_S, N_SIMDWIDTH>>>;
|
7009
|
+
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH>>>;
|
7010
|
+
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH>>>;
|
6634
7011
|
|
6635
7012
|
kernel void kernel_pool_2d_max_f32(
|
6636
7013
|
device const float * src0,
|
6637
7014
|
device float * dst,
|
6638
|
-
constant
|
6639
|
-
constant int32_t & k1,
|
6640
|
-
constant int32_t & s0,
|
6641
|
-
constant int32_t & s1,
|
6642
|
-
constant int32_t & p0,
|
6643
|
-
constant int32_t & p1,
|
6644
|
-
constant int64_t & IH,
|
6645
|
-
constant int64_t & IW,
|
6646
|
-
constant int64_t & OH,
|
6647
|
-
constant int64_t & OW,
|
6648
|
-
constant int64_t & parallel_elements,
|
7015
|
+
constant ggml_metal_kargs_pool_2d & args,
|
6649
7016
|
uint gid[[thread_position_in_grid]]) {
|
6650
7017
|
|
6651
|
-
if (gid >= parallel_elements) {
|
7018
|
+
if (gid >= args.parallel_elements) {
|
6652
7019
|
return;
|
6653
7020
|
}
|
6654
7021
|
|
6655
7022
|
const int idx = gid;
|
6656
|
-
const int I_HW = IH * IW;
|
6657
|
-
const int O_HW = OH * OW;
|
7023
|
+
const int I_HW = args.IH * args.IW;
|
7024
|
+
const int O_HW = args.OH * args.OW;
|
6658
7025
|
const int nc = idx / O_HW;
|
6659
|
-
const int cur_oh = idx % O_HW / OW;
|
6660
|
-
const int cur_ow = idx % O_HW % OW;
|
7026
|
+
const int cur_oh = idx % O_HW / args.OW;
|
7027
|
+
const int cur_ow = idx % O_HW % args.OW;
|
6661
7028
|
|
6662
7029
|
device const float * i_ptr = src0 + nc * I_HW;
|
6663
7030
|
device float * o_ptr = dst + nc * O_HW;
|
6664
7031
|
|
6665
|
-
const int start_h = cur_oh * s1 - p1;
|
7032
|
+
const int start_h = cur_oh * args.s1 - args.p1;
|
6666
7033
|
const int bh = MAX(0, start_h);
|
6667
|
-
const int eh = MIN(IH, start_h + k1);
|
6668
|
-
const int start_w = cur_ow * s0 - p0;
|
7034
|
+
const int eh = MIN(args.IH, start_h + args.k1);
|
7035
|
+
const int start_w = cur_ow * args.s0 - args.p0;
|
6669
7036
|
const int bw = MAX(0, start_w);
|
6670
|
-
const int ew = MIN(IW, start_w + k0);
|
7037
|
+
const int ew = MIN(args.IW, start_w + args.k0);
|
6671
7038
|
|
6672
7039
|
float res = -INFINITY;
|
6673
7040
|
|
6674
7041
|
for (int i = bh; i < eh; i += 1) {
|
6675
7042
|
for (int j = bw; j < ew; j += 1) {
|
6676
|
-
res = MAX(res, i_ptr[i * IW + j]);
|
7043
|
+
res = MAX(res, i_ptr[i * args.IW + j]);
|
6677
7044
|
}
|
6678
7045
|
}
|
6679
7046
|
|
6680
|
-
o_ptr[cur_oh * OW + cur_ow] = res;
|
7047
|
+
o_ptr[cur_oh * args.OW + cur_ow] = res;
|
6681
7048
|
}
|
6682
7049
|
|
6683
7050
|
kernel void kernel_pool_2d_avg_f32(
|
6684
7051
|
device const float * src0,
|
6685
7052
|
device float * dst,
|
6686
|
-
constant
|
6687
|
-
constant int32_t & k1,
|
6688
|
-
constant int32_t & s0,
|
6689
|
-
constant int32_t & s1,
|
6690
|
-
constant int32_t & p0,
|
6691
|
-
constant int32_t & p1,
|
6692
|
-
constant int64_t & IH,
|
6693
|
-
constant int64_t & IW,
|
6694
|
-
constant int64_t & OH,
|
6695
|
-
constant int64_t & OW,
|
6696
|
-
constant int64_t & parallel_elements,
|
7053
|
+
constant ggml_metal_kargs_pool_2d & args,
|
6697
7054
|
uint gid[[thread_position_in_grid]]) {
|
6698
7055
|
|
6699
|
-
if (gid >= parallel_elements) {
|
7056
|
+
if (gid >= args.parallel_elements) {
|
6700
7057
|
return;
|
6701
7058
|
}
|
6702
7059
|
|
6703
7060
|
const int idx = gid;
|
6704
|
-
const int I_HW = IH * IW;
|
6705
|
-
const int O_HW = OH * OW;
|
7061
|
+
const int I_HW = args.IH * args.IW;
|
7062
|
+
const int O_HW = args.OH * args.OW;
|
6706
7063
|
const int nc = idx / O_HW;
|
6707
|
-
const int cur_oh = idx % O_HW / OW;
|
6708
|
-
const int cur_ow = idx % O_HW % OW;
|
7064
|
+
const int cur_oh = idx % O_HW / args.OW;
|
7065
|
+
const int cur_ow = idx % O_HW % args.OW;
|
6709
7066
|
|
6710
7067
|
device const float * i_ptr = src0 + nc * I_HW;
|
6711
7068
|
device float * o_ptr = dst + nc * O_HW;
|
6712
7069
|
|
6713
|
-
const int start_h = cur_oh * s1 - p1;
|
7070
|
+
const int start_h = cur_oh * args.s1 - args.p1;
|
6714
7071
|
const int bh = MAX(0, start_h);
|
6715
|
-
const int eh = MIN(IH, start_h + k1);
|
6716
|
-
const int start_w = cur_ow * s0 - p0;
|
7072
|
+
const int eh = MIN(args.IH, start_h + args.k1);
|
7073
|
+
const int start_w = cur_ow * args.s0 - args.p0;
|
6717
7074
|
const int bw = MAX(0, start_w);
|
6718
|
-
const int ew = MIN(IW, start_w + k0);
|
7075
|
+
const int ew = MIN(args.IW, start_w + args.k0);
|
6719
7076
|
// const float scale = 1. / ((eh - bh) * (ew - bw));
|
6720
|
-
const float scale = 1. / (k0 * k1);
|
7077
|
+
const float scale = 1. / (args.k0 * args.k1);
|
6721
7078
|
|
6722
7079
|
float res = 0;
|
6723
7080
|
|
6724
7081
|
for (int i = bh; i < eh; i += 1) {
|
6725
7082
|
for (int j = bw; j < ew; j += 1) {
|
6726
|
-
float cur = i_ptr[i * IW + j];
|
7083
|
+
float cur = i_ptr[i * args.IW + j];
|
6727
7084
|
res += cur * scale;
|
6728
7085
|
}
|
6729
7086
|
}
|
6730
7087
|
|
6731
|
-
o_ptr[cur_oh * OW + cur_ow] = res;
|
7088
|
+
o_ptr[cur_oh * args.OW + cur_ow] = res;
|
6732
7089
|
}
|