whispercpp 1.3.1 → 1.3.3
This diff covers the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +7 -3
- data/README.md +161 -43
- data/Rakefile +45 -13
- data/ext/.gitignore +4 -8
- data/ext/dependencies.rb +73 -0
- data/ext/extconf.rb +21 -198
- data/ext/options.rb +85 -0
- data/ext/ruby_whisper.c +177 -0
- data/ext/ruby_whisper.h +17 -2
- data/ext/ruby_whisper_context.c +672 -0
- data/ext/ruby_whisper_error.c +52 -0
- data/ext/ruby_whisper_model.c +232 -0
- data/ext/ruby_whisper_params.c +1303 -0
- data/ext/ruby_whisper_segment.c +220 -0
- data/ext/ruby_whisper_transcribe.cpp +93 -0
- data/ext/ruby_whisper_vad_params.c +288 -0
- data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
- data/ext/sources/CMakeLists.txt +255 -0
- data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
- data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
- data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
- data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
- data/ext/sources/bindings/javascript/package.json +26 -0
- data/ext/sources/bindings/javascript/whisper.js +19 -0
- data/ext/sources/build-xcframework.sh +547 -0
- data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
- data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
- data/ext/sources/cmake/build-info.cmake +60 -0
- data/ext/sources/cmake/git-vars.cmake +22 -0
- data/ext/sources/cmake/whisper-config.cmake.in +65 -0
- data/ext/sources/cmake/whisper.pc.in +10 -0
- data/ext/sources/examples/CMakeLists.txt +124 -0
- data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +133 -0
- data/ext/sources/examples/addon.node/addon.cpp +557 -0
- data/ext/sources/examples/addon.node/index.js +57 -0
- data/ext/sources/examples/addon.node/package.json +16 -0
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/CMakeLists.txt +8 -0
- data/ext/sources/examples/bench/bench.cpp +176 -0
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
- data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
- data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
- data/ext/sources/examples/cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/cli/cli.cpp +1295 -0
- data/ext/sources/examples/coi-serviceworker.js +146 -0
- data/ext/sources/examples/command/CMakeLists.txt +10 -0
- data/ext/sources/examples/command/command.cpp +800 -0
- data/ext/sources/examples/command/commands.txt +9 -0
- data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
- data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
- data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
- data/ext/sources/examples/common-ggml.cpp +238 -0
- data/ext/sources/examples/common-ggml.h +18 -0
- data/ext/sources/examples/common-sdl.cpp +227 -0
- data/ext/sources/examples/common-sdl.h +49 -0
- data/ext/sources/examples/common-whisper.cpp +175 -0
- data/ext/sources/examples/common-whisper.h +24 -0
- data/ext/sources/examples/common.cpp +675 -0
- data/ext/sources/examples/common.h +322 -0
- data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
- data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
- data/ext/sources/examples/generate-karaoke.sh +57 -0
- data/ext/sources/examples/grammar-parser.cpp +423 -0
- data/ext/sources/examples/grammar-parser.h +29 -0
- data/ext/sources/examples/helpers.js +191 -0
- data/ext/sources/examples/json.hpp +24596 -0
- data/ext/sources/examples/livestream.sh +112 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
- data/ext/sources/examples/lsp/lsp.cpp +469 -0
- data/ext/sources/examples/lsp/whisper.vim +362 -0
- data/ext/sources/examples/miniaudio.h +93468 -0
- data/ext/sources/examples/python/test_whisper_processor.py +7 -0
- data/ext/sources/examples/python/whisper_processor.py +54 -0
- data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
- data/ext/sources/examples/quantize/quantize.cpp +226 -0
- data/ext/sources/examples/server/CMakeLists.txt +15 -0
- data/ext/sources/examples/server/bench.js +29 -0
- data/ext/sources/examples/server/httplib.h +10497 -0
- data/ext/sources/examples/server/server.cpp +1238 -0
- data/ext/sources/examples/server.py +115 -0
- data/ext/sources/examples/stb_vorbis.c +5584 -0
- data/ext/sources/examples/stream/CMakeLists.txt +10 -0
- data/ext/sources/examples/stream/stream.cpp +435 -0
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
- data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
- data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
- data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
- data/ext/sources/examples/sycl/build.sh +22 -0
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
- data/ext/sources/examples/sycl/run-whisper.sh +17 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +43 -0
- data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
- data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +1914 -0
- data/ext/sources/examples/talk-llama/llama-arch.h +464 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +843 -0
- data/ext/sources/examples/talk-llama/llama-batch.h +147 -0
- data/ext/sources/examples/talk-llama/llama-chat.cpp +685 -0
- data/ext/sources/examples/talk-llama/llama-chat.h +59 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +2845 -0
- data/ext/sources/examples/talk-llama/llama-context.h +297 -0
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
- data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
- data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
- data/ext/sources/examples/talk-llama/llama-graph.cpp +1693 -0
- data/ext/sources/examples/talk-llama/llama-graph.h +710 -0
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +103 -0
- data/ext/sources/examples/talk-llama/llama-hparams.h +207 -0
- data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
- data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
- data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
- data/ext/sources/examples/talk-llama/llama-io.h +35 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +44 -0
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +439 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +59 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +116 -0
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
- data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1163 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +282 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +15114 -0
- data/ext/sources/examples/talk-llama/llama-model.h +452 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +1049 -0
- data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
- data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +3377 -0
- data/ext/sources/examples/talk-llama/llama-vocab.h +132 -0
- data/ext/sources/examples/talk-llama/llama.cpp +358 -0
- data/ext/sources/examples/talk-llama/llama.h +1484 -0
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
- data/ext/sources/examples/talk-llama/speak +40 -0
- data/ext/sources/examples/talk-llama/speak.bat +1 -0
- data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
- data/ext/sources/examples/talk-llama/talk-llama.cpp +810 -0
- data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
- data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +854 -0
- data/ext/sources/examples/talk-llama/unicode.h +66 -0
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +149 -0
- data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
- data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +251 -0
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
- data/ext/sources/ggml/CMakeLists.txt +435 -0
- data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
- data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
- data/ext/sources/ggml/cmake/common.cmake +50 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
- data/ext/{ggml → sources/ggml}/include/ggml-alloc.h +1 -1
- data/ext/{ggml → sources/ggml}/include/ggml-backend.h +10 -8
- data/ext/{ggml → sources/ggml}/include/ggml-cpp.h +2 -1
- data/ext/{ggml → sources/ggml}/include/ggml-cpu.h +11 -1
- data/ext/{ggml → sources/ggml}/include/ggml-metal.h +1 -1
- data/ext/{ggml → sources/ggml}/include/ggml-opt.h +49 -28
- data/ext/{ggml → sources/ggml}/include/ggml-rpc.h +6 -1
- data/ext/{ggml → sources/ggml}/include/ggml-vulkan.h +0 -2
- data/ext/{ggml → sources/ggml}/include/ggml.h +325 -269
- data/ext/sources/ggml/include/gguf.h +202 -0
- data/ext/sources/ggml/src/CMakeLists.txt +404 -0
- data/ext/{ggml → sources/ggml}/src/ggml-alloc.c +34 -29
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- data/ext/{ggml → sources/ggml}/src/ggml-backend-impl.h +1 -2
- data/ext/{ggml → sources/ggml}/src/ggml-backend-reg.cpp +92 -53
- data/ext/{ggml → sources/ggml}/src/ggml-backend.cpp +69 -34
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +75 -0
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.cpp +10 -4
- data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.h +5 -5
- data/ext/{ggml → sources/ggml}/src/ggml-cann/aclnn_ops.cpp +1272 -1506
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cann/common.h +140 -1
- data/ext/{ggml → sources/ggml}/src/ggml-cann/ggml-cann.cpp +588 -146
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/dup.cpp +3 -5
- data/ext/{ggml → sources/ggml}/src/ggml-common.h +16 -8
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +597 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.cpp +3 -2
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
- data/ext/{ggml/src/ggml-cpu/cpu-feats-x86.cpp → sources/ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp} +5 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +3285 -0
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +73 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-impl.h +172 -41
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3551 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu.cpp +78 -25
- data/ext/{ggml/src/ggml-cpu/ggml-cpu-hbm.cpp → sources/ggml/src/ggml-cpu/hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3594 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +19 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +9786 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.h +118 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
- data/ext/{ggml/src/ggml-cpu/ggml-cpu-quants.h → sources/ggml/src/ggml-cpu/quants.h} +26 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +1184 -0
- data/ext/{ggml/src/ggml-cpu/ggml-cpu-traits.cpp → sources/ggml/src/ggml-cpu/traits.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +345 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +1027 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
- data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
- data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
- data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
- data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +851 -0
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
- data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +752 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +31 -0
- data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
- data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
- data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
- data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1474 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +638 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3647 -0
- data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
- data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +506 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +11 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
- data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +155 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
- data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +26 -0
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +378 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +66 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/cuda.h +1 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/hip.h +57 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/musa.h +7 -1
- data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
- data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +135 -0
- data/ext/{ggml → sources/ggml}/src/ggml-impl.h +147 -158
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +121 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +649 -0
- data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.m +2504 -1108
- data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.metal +2102 -1463
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
- data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +110 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +6494 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- data/ext/{ggml → sources/ggml}/src/ggml-opt.cpp +373 -190
- data/ext/{ggml → sources/ggml}/src/ggml-quants.c +120 -128
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
- data/ext/{ggml → sources/ggml}/src/ggml-rpc/ggml-rpc.cpp +494 -84
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +344 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/common.cpp +20 -32
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +561 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/concat.cpp +56 -70
- data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/conv.cpp +8 -12
- data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +575 -0
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +839 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +823 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/dmmv.cpp +188 -67
- data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2987 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1120 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +84 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +102 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +212 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/ggml-sycl.cpp +1197 -1295
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
- data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmq.cpp +60 -81
- data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1065 -0
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +482 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/outprod.cpp +8 -17
- data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +111 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +472 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/softmax.cpp +38 -28
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +15 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +26 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/tsembd.cpp +6 -11
- data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1307 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +289 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +200 -0
- data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
- data/ext/{ggml → sources/ggml}/src/ggml-vulkan/ggml-vulkan.cpp +3822 -1335
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +31 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +61 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
- data/ext/{ggml → sources/ggml}/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +203 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
- data/ext/{ggml → sources/ggml}/src/ggml.c +918 -1782
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +1351 -0
- data/ext/{include → sources/include}/whisper.h +70 -2
- data/ext/sources/src/CMakeLists.txt +145 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/{src → sources/src}/coreml/whisper-decoder-impl.h +27 -15
- data/ext/{src → sources/src}/coreml/whisper-decoder-impl.m +36 -10
- data/ext/{src → sources/src}/coreml/whisper-encoder-impl.h +21 -9
- data/ext/{src → sources/src}/coreml/whisper-encoder-impl.m +29 -3
- data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
- data/ext/sources/src/whisper-arch.h +197 -0
- data/ext/{src → sources/src}/whisper.cpp +1966 -386
- data/ext/sources/tests/CMakeLists.txt +105 -0
- data/ext/sources/tests/earnings21/eval.mk +58 -0
- data/ext/sources/tests/earnings21/eval.py +68 -0
- data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
- data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
- data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
- data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
- data/ext/sources/tests/earnings21/requirements.txt +6 -0
- data/ext/sources/tests/en-0-ref.txt +1 -0
- data/ext/sources/tests/en-1-ref.txt +1 -0
- data/ext/sources/tests/en-2-ref.txt +1 -0
- data/ext/sources/tests/es-0-ref.txt +1 -0
- data/ext/sources/tests/librispeech/eval.mk +39 -0
- data/ext/sources/tests/librispeech/eval.py +47 -0
- data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
- data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
- data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
- data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
- data/ext/sources/tests/librispeech/requirements.txt +6 -0
- data/ext/sources/tests/run-tests.sh +130 -0
- data/ext/sources/tests/test-c.c +3 -0
- data/ext/sources/tests/test-vad-full.cpp +54 -0
- data/ext/sources/tests/test-vad.cpp +83 -0
- data/ext/sources/tests/test-whisper.js +58 -0
- data/extsources.rb +39 -5
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +202 -126
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +510 -0
- data/test/helper.rb +24 -0
- data/{tests → test}/test_callback.rb +45 -3
- data/{tests → test}/test_error.rb +2 -2
- data/{tests → test}/test_model.rb +47 -0
- data/test/test_package.rb +51 -0
- data/test/test_params.rb +297 -0
- data/test/test_segment.rb +146 -0
- data/test/test_vad.rb +19 -0
- data/test/test_vad_params.rb +103 -0
- data/{tests → test}/test_whisper.rb +106 -36
- data/whispercpp.gemspec +5 -5
- metadata +837 -134
- data/ext/cpu.mk +0 -9
- data/ext/examples/dr_wav.h +0 -8815
- data/ext/ggml/src/ggml-cann/aclnn_ops.h +0 -592
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -4262
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -10835
- data/ext/ggml/src/ggml-cpu/ggml-cpu.c +0 -14123
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +0 -1884
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +0 -14
- data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +0 -288
- data/ext/ggml/src/ggml-sycl/convert.cpp +0 -547
- data/ext/ggml/src/ggml-sycl/element_wise.cpp +0 -1030
- data/ext/ggml/src/ggml-sycl/im2col.cpp +0 -126
- data/ext/ggml/src/ggml-sycl/mmvq.cpp +0 -1015
- data/ext/ggml/src/ggml-sycl/norm.cpp +0 -378
- data/ext/ggml/src/ggml-sycl/rope.cpp +0 -276
- data/ext/ggml/src/ggml-sycl/wkv6.cpp +0 -141
- data/ext/metal-embed.mk +0 -17
- data/ext/metal.mk +0 -6
- data/ext/ruby_whisper.cpp +0 -1909
- data/ext/scripts/get-flags.mk +0 -38
- data/lib/whisper.rb +0 -2
- data/tests/helper.rb +0 -7
- data/tests/test_package.rb +0 -31
- data/tests/test_params.rb +0 -160
- data/tests/test_segment.rb +0 -83
- /data/ext/{ggml → sources/ggml}/include/ggml-blas.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-cann.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-cuda.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-kompute.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-opencl.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-sycl.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/common.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/ggml-amx.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-blas/ggml-blas.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/ascendc_kernels.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f16.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f32.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/common.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.h +0 -0
- /data/ext/{ggml/src/ggml-cpu/ggml-cpu-hbm.h → sources/ggml/src/ggml-cpu/hbm.h} +0 -0
- /data/ext/{ggml/src/ggml-cpu/ggml-cpu-traits.h → sources/ggml/src/ggml-cpu/traits.h} +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-kompute/ggml-kompute.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-quants.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-threading.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-threading.h +0 -0
- /data/ext/{src → sources/src}/coreml/whisper-encoder.h +0 -0
- /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.cpp +0 -0
- /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.h +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
```diff
@@ -37,10 +37,20 @@
 #include "ggml-backend-impl.h"
 
 #include "ggml-sycl/backend.hpp"
+#include "ggml-sycl/common.hpp"
+#include "ggml-sycl/element_wise.hpp"
 #include "ggml-sycl/presets.hpp"
 #include "ggml-sycl/gemm.hpp"
+#include "ggml-sycl/sycl_hw.hpp"
+#include "ggml-sycl/getrows.hpp"
+#include "ggml.h"
 
 static bool g_sycl_loaded = false;
+int g_ggml_sycl_debug = 0;
+int g_ggml_sycl_disable_optimize = 0;
+int g_ggml_sycl_disable_graph = 0;
+int g_ggml_sycl_disable_dnn = 0;
+int g_ggml_sycl_prioritize_dmmv = 0;
 
 static ggml_sycl_device_info ggml_sycl_init() {
     ggml_sycl_device_info info = {};
@@ -54,30 +64,26 @@ static ggml_sycl_device_info ggml_sycl_init() {
     GGML_ASSERT(info.device_count <= GGML_SYCL_MAX_DEVICES);
 
     int64_t total_vram = 0;
-#if defined(GGML_SYCL_FORCE_MMQ)
-    GGML_LOG_INFO("%s: GGML_SYCL_FORCE_MMQ:   yes\n", __func__);
-#else
-    GGML_LOG_INFO("%s: GGML_SYCL_FORCE_MMQ:   no\n", __func__);
-#endif
-#if defined(SYCL_USE_XMX)
-    GGML_LOG_INFO("%s: SYCL_USE_XMX: yes\n", __func__);
-#else
-    GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__);
-#endif
-    GGML_LOG_INFO("%s: found %d %s devices:\n", __func__, info.device_count, GGML_SYCL_NAME);
-
+    /* This is a bit misleading; reserved for later */
+    // #if defined(SYCL_USE_XMX)
+    //     GGML_LOG_INFO("%s: SYCL_USE_XMX: yes\n", __func__);
+    // #else
+    //     GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__);
+    // #endif
     for (int i = 0; i < info.device_count; ++i) {
         info.devices[i].vmm = 0;
         dpct::device_info prop;
+        sycl::device device = dpct::dev_mgr::instance().get_device(i);
+
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
-            prop, dpct::dev_mgr::instance().get_device(i))));
+            prop, device)));
 
         info.default_tensor_split[i] = total_vram;
         total_vram += prop.get_global_mem_size();
 
         info.devices[i].cc =
             100 * prop.get_major_version() + 10 * prop.get_minor_version();
-
+        info.devices[i].opt_feature.reorder = !device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
         info.max_work_group_sizes[i] = prop.get_max_work_group_size();
     }
 
@@ -92,7 +98,7 @@ const ggml_sycl_device_info & ggml_sycl_info() {
     return info;
 }
 
-void print_device_detail(int id, sycl::device &device, std::string device_type) {
+static void print_device_detail(int id, sycl::device &device, std::string device_type) {
 
     dpct::device_info prop;
     SYCL_CHECK(CHECK_TRY_ERROR(
@@ -109,13 +115,33 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
     name = std::regex_replace(name, std::regex("\\(TM\\)"), "");
 
     auto global_mem_size = prop.get_global_mem_size()/1000000;
-
     GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(),
                   name.c_str(), version.c_str(), prop.get_max_compute_units(),
                   prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
                   global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
 }
 
+static void print_device_opt_feature(int device_count) {
+    GGML_LOG_INFO("SYCL Optimization Feature:\n");
+    GGML_LOG_INFO(
+        "|ID|        Device Type|Reorder|\n");
+    GGML_LOG_INFO(
+        "|--|-------------------|-------|\n");
+    std::map<std::string, size_t> DeviceNums;
+    for (int id = 0; id < device_count; ++id) {
+        sycl::device device = dpct::dev_mgr::instance().get_device(id);
+        std::string backend_type = get_device_backend_and_type(device);
+        int type_id = DeviceNums[backend_type]++;
+        std::stringstream device_type;
+        device_type << "[" << backend_type << ":" << std::to_string(type_id)
+                    << "]";
+        std::string device_type_s = device_type.str();
+        device_type_s = std::regex_replace(device_type_s, std::regex("ext_oneapi_"), "");
+        GGML_LOG_INFO("|%2d|%19s|%7s|\n", id, device_type_s.c_str(),
+                      ggml_sycl_info().devices[id].opt_feature.reorder ? "Y": "N");
+    }
+
+}
 void ggml_backend_sycl_print_sycl_devices() {
     GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n");
     int device_count = dpct::dev_mgr::instance().device_count();
@@ -144,6 +170,8 @@ void ggml_backend_sycl_print_sycl_devices() {
                     << "]";
         print_device_detail(id, device, device_type.str());
     }
+
+    print_device_opt_feature(device_count);
 }
 
 static inline int get_sycl_env(const char *env_name, int default_val) {
@@ -164,14 +192,36 @@ static void ggml_check_sycl() try {
     static bool initialized = false;
 
     if (!initialized) {
-        GGML_LOG_INFO("[SYCL] call ggml_check_sycl\n");
         g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
-        GGML_LOG_INFO("%s: GGML_SYCL_DEBUG: %d\n", __func__, g_ggml_sycl_debug);
-
+        g_ggml_sycl_disable_optimize = get_sycl_env("GGML_SYCL_DISABLE_OPT", 0);
+        g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
+        g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
+        g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
+        GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
+        GGML_LOG_INFO("Running with Environment Variables:\n");
+        GGML_LOG_INFO("  GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
+        GGML_LOG_INFO("  GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
+#ifdef GGML_SYCL_GRAPH
+        GGML_LOG_INFO("  GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
+#else
+        GGML_LOG_INFO("  GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
+#endif
+#if GGML_SYCL_DNNL
+        GGML_LOG_INFO("  GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
+#else
+        GGML_LOG_INFO("  GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
+#endif
+        GGML_LOG_INFO("  GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
+        GGML_LOG_INFO("Build with Macros:\n");
+#if defined(GGML_SYCL_FORCE_MMQ)
+        GGML_LOG_INFO("  GGML_SYCL_FORCE_MMQ: yes\n");
+#else
+        GGML_LOG_INFO("  GGML_SYCL_FORCE_MMQ: no\n");
+#endif
 #if defined(GGML_SYCL_F16)
-        GGML_LOG_INFO("%s: GGML_SYCL_F16: yes\n", __func__);
+        GGML_LOG_INFO("  GGML_SYCL_F16: yes\n");
 #else
-        GGML_LOG_INFO("%s: GGML_SYCL_F16: no\n", __func__);
+        GGML_LOG_INFO("  GGML_SYCL_F16: no\n");
 #endif
 
 /* NOT REMOVE, keep it for next optimize for XMX.
```
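The hunk above routes every new toggle through the `get_sycl_env` helper that appears in context. A minimal sketch of that pattern, assuming only the C standard library (the helper name `get_env_int` and the `main` harness below are illustrative, not part of ggml):

```c++
// Sketch of an integer env-var toggle in the style of get_sycl_env.
#include <cstdio>
#include <cstdlib>

static int get_env_int(const char * name, int default_val) {
    const char * s = std::getenv(name);  // nullptr when the variable is unset
    return s ? std::atoi(s) : default_val;
}

int main() {
    // e.g. run as: GGML_SYCL_DEBUG=1 ./a.out
    int debug         = get_env_int("GGML_SYCL_DEBUG", 0);
    int disable_graph = get_env_int("GGML_SYCL_DISABLE_GRAPH", 1);  // graph path defaults to off
    std::printf("debug=%d disable_graph=%d\n", debug, disable_graph);
    return 0;
}
```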
```diff
@@ -243,19 +293,27 @@ struct ggml_backend_sycl_buffer_context {
     void * dev_ptr = nullptr;
     queue_ptr stream;
     std::string name;
+    optimize_feature opt_feature;
+    std::vector<ggml_tensor_extra_gpu *> tensor_extras;
 
-
+    ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :
         device(device), dev_ptr(dev_ptr), stream(stream) {
         check_allow_gpu_index(device);
         name = (GGML_SYCL_NAME + std::to_string(device));
+        opt_feature = ggml_sycl_info().devices[device].opt_feature;
     }
 
-
     ~ggml_backend_sycl_buffer_context() {
         if (dev_ptr != nullptr) {
             ggml_sycl_set_device(device);
             SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream)));
         }
+
+        //release extra used by tensors
+        for (ggml_tensor_extra_gpu * extra : tensor_extras) {
+            release_extra_gpu(extra);
+        }
+
     }
 };
 
@@ -283,18 +341,23 @@ static void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) {
     return ctx->dev_ptr;
 }
 
-static void
+static enum ggml_status
 ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
                                      ggml_tensor *tensor) try {
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor, "\n").c_str());
    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context;
 
-    if (tensor->view_src != NULL && tensor->view_offs == 0) {
+    if (tensor->view_src != NULL) {
         assert(tensor->view_src->buffer->buft == buffer->buft);
-        tensor->backend = tensor->view_src->backend;
-        tensor->extra = tensor->view_src->extra;
-        return;
+        return GGML_STATUS_SUCCESS;
+    }
+    if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) &&
+        !g_ggml_sycl_disable_optimize) {
+        ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
+        tensor->extra = extra;
+        ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
     }
-
 
     if (ggml_is_quantized(tensor->type)) {
         // initialize padding to 0 to avoid possible NaN values
@@ -307,6 +370,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
                                          padded_size - original_size).wait()));
         }
     }
+    return GGML_STATUS_SUCCESS;
 }
 catch (sycl::exception const &exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -318,19 +382,23 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                                 ggml_tensor *tensor,
                                                 const void *data, size_t offset,
                                                 size_t size) try {
-
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
+    GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
     ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
-
     ggml_sycl_set_device(ctx->device);
     auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
-    SYCL_CHECK(
-        CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
-    char * host_buf = (char *)malloc(size);
+    SYCL_CHECK(CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
+#ifndef _WIN32
+    // Note: Use host buffer to save the data from mmap(), then copy to device. It's workaround for mmap() issue on PVC GPU.
+    // This function will be called during load model from disk. Use memory buffer replace dynamic won't save more time and brings potential memory leak risk here.
+    char * host_buf = (char *) malloc(size);
     memcpy(host_buf, data, size);
-    SYCL_CHECK(
-        CHECK_TRY_ERROR((*stream).memcpy((char *)tensor->data + offset, host_buf, size)
-            .wait()));
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, host_buf, size).wait()));
     free(host_buf);
+#else
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, data, size).wait()));
+#endif
 }
 catch (sycl::exception const &exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
```
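The `#ifndef _WIN32` path above stages mmap()-backed model data through a plain heap buffer before handing it to the device queue. A standalone sketch of that staging pattern, assuming a ready `sycl::queue` (the function name below is illustrative):

```c++
// Staging copy: detach from the (possibly mmap-backed) source on the host
// before the device upload, mirroring the non-Windows branch above.
#include <cstdlib>
#include <cstring>
#include <sycl/sycl.hpp>

void staged_copy_to_device(sycl::queue & q, void * dst, const void * src, size_t size) {
    char * host_buf = (char *) std::malloc(size);  // plain heap staging buffer
    std::memcpy(host_buf, src, size);              // read once from the mmap()ed region
    q.memcpy(dst, host_buf, size).wait();          // synchronous upload to the device
    std::free(host_buf);
}
```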
```diff
@@ -342,7 +410,9 @@ static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer,
                                                 const ggml_tensor *tensor,
                                                 void *data, size_t offset,
                                                 size_t size) try {
-
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
+    GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
     ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
 
     ggml_sycl_set_device(ctx->device);
@@ -358,7 +428,7 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
+static void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
                     const void *ptr_src, size_t size) {
     char *host_buf = (char *)malloc(size);
     q_src.memcpy(host_buf, (const char *)ptr_src, size).wait();
@@ -370,7 +440,12 @@ static bool
 ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
                                     const ggml_tensor *src,
                                     ggml_tensor *dst) try {
-    if (ggml_backend_buffer_is_sycl(src->buffer)) {
+    bool is_cpy_supported = ggml_backend_buffer_is_sycl(src->buffer);
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": dst", dst).c_str());
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" src", src).c_str());
+    GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported);
+    if (is_cpy_supported) {
         ggml_backend_sycl_buffer_context * src_ctx = (ggml_backend_sycl_buffer_context *)src->buffer->context;
         ggml_backend_sycl_buffer_context * dst_ctx = (ggml_backend_sycl_buffer_context *)dst->buffer->context;
 
@@ -427,7 +502,8 @@ ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
 
 static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer,
                                            uint8_t value) try {
-
+    GGML_SYCL_DEBUG("[SYCL] call %s: size=%zu\n", __func__, buffer->size);
+    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
 
     ggml_sycl_set_device(ctx->device);
     queue_ptr stream = ctx->stream;
@@ -444,16 +520,51 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
+static void ggml_backend_sycl_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value,
+                                                   size_t offset, size_t size) {
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
+    GGML_SYCL_DEBUG(" size=%zu offset=%zu value=%u\n", size, offset, value);
+    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
+    SYCL_CHECK(ggml_sycl_set_device(ctx->device));
+    auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
+    if (size == 0) {
+        return;  // Nothing to do
+    }
+    if (tensor->data == nullptr) {
+        GGML_ABORT("Error: Tensor data pointer is null.\n");
+    }
+    void * target_ptr = static_cast<char *>(tensor->data) + offset;
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memset(target_ptr, value, size)));
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream).wait()));
+}
+
+static void ggml_backend_sycl_buffer_reset(ggml_backend_buffer_t buffer) {
+    GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
+    if (buffer == nullptr) {
+        return;
+    }
+
+    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
+
+    if (ctx != nullptr) {
+        for (ggml_tensor_extra_gpu * extra : ctx->tensor_extras) {
+            release_extra_gpu(extra);
+        }
+        ctx->tensor_extras.clear();  // reset the tensor_extras vector
+    }
+}
+
 static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_sycl_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_sycl_buffer_get_base,
     /* .init_tensor     = */ ggml_backend_sycl_buffer_init_tensor,
-    /* .memset_tensor   = */ NULL,
+    /* .memset_tensor   = */ ggml_backend_sycl_buffer_memset_tensor,
     /* .set_tensor      = */ ggml_backend_sycl_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_sycl_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_sycl_buffer_cpy_tensor,
     /* .clear           = */ ggml_backend_sycl_buffer_clear,
-    /* .reset           = */ NULL,
+    /* .reset           = */ ggml_backend_sycl_buffer_reset,
 };
 
 // sycl buffer type
@@ -534,12 +645,11 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
     static std::mutex mutex;
     std::lock_guard<std::mutex> lock(mutex);
 
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
 
     auto dev_count = ggml_backend_sycl_get_device_count();
 
     if (device>=dev_count or device<0) {
-
+        GGML_LOG_ERROR("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
                device, dev_count-1);
         GGML_ASSERT(device<dev_count);
     }
@@ -562,12 +672,12 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
     return &ggml_backend_sycl_buffer_types[device];
 }
 
-ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
+static ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
     GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
 
     int device = ctx->device;
     if (device>=ggml_sycl_info().device_count or device<0) {
-
+        GGML_LOG_ERROR("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
                device, ggml_sycl_info().device_count-1);
         GGML_ASSERT(device<ggml_sycl_info().device_count);
     }
@@ -664,32 +774,7 @@ struct ggml_backend_sycl_split_buffer_type_context {
 struct ggml_backend_sycl_split_buffer_context {
     ~ggml_backend_sycl_split_buffer_context() try {
         for (ggml_tensor_extra_gpu * extra : tensor_extras) {
-            for (int i = 0; i < GGML_SYCL_MAX_DEVICES; ++i) {
-                for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
-                    if (extra->events[i][is] != nullptr) {
-                        /*
-                        DPCT1009:206: SYCL uses exceptions to report errors and
-                        does not use the error codes. The original code was
-                        commented out and a warning string was inserted. You
-                        need to rewrite this code.
-                        */
-                        SYCL_CHECK(CHECK_TRY_ERROR(
-                            dpct::destroy_event(extra->events[i][is])));
-                    }
-                }
-                if (extra->data_device[i] != nullptr) {
-                    /*
-                    DPCT1009:207: SYCL uses exceptions to report errors and does
-                    not use the error codes. The original code was commented out
-                    and a warning string was inserted. You need to rewrite this
-                    code.
-                    */
-                    ggml_sycl_set_device(i);
-                    SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(
-                        extra->data_device[i], *(streams[i]))));
-                }
-            }
-            delete extra;
+            release_extra_gpu(extra, streams);
         }
     }
     catch (sycl::exception const &exc) {
@@ -714,9 +799,11 @@ static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) {
     GGML_UNUSED(buffer);
 }
 
-static void
+static enum ggml_status
 ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
                                            ggml_tensor *tensor) try {
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor, "\n").c_str());
     GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
 
     ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
@@ -727,7 +814,7 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
     ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
 
     ctx->tensor_extras.push_back(extra);
-
+    ctx->streams.push_back(&(dpct::get_current_device().default_queue()));
 
     for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
         int64_t row_low, row_high;
@@ -746,7 +833,7 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
             size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
         }
 
-        // FIXME: do not crash if cudaMalloc fails
+        // FIXME: do not crash if SYCL Buffer alloc fails
         // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
         ggml_sycl_set_device(i);
         const queue_ptr stream = ctx->streams[i];
@@ -788,8 +875,8 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
                 CHECK_TRY_ERROR(extra->events[i][is] = new sycl::event()));
         }
     }
-    tensor->backend = GGML_BACKEND_TYPE_GPU_SPLIT;
     tensor->extra = extra;
+    return GGML_STATUS_SUCCESS;
 }
 catch (sycl::exception const &exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -801,6 +888,9 @@ static void
 ggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                           ggml_tensor *tensor, const void *data,
                                           size_t offset, size_t size) try {
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
+    GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
     // split tensors must always be set in their entirety at once
     GGML_ASSERT(offset == 0);
     GGML_ASSERT(size == ggml_nbytes(tensor));
@@ -854,6 +944,9 @@ static void
 ggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer,
                                           const ggml_tensor *tensor, void *data,
                                           size_t offset, size_t size) try {
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
+    GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
     // split tensors must always be set in their entirety at once
     GGML_ASSERT(offset == 0);
     GGML_ASSERT(size == ggml_nbytes(tensor));
@@ -1178,6 +1271,85 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool {
     }
 };
 
+struct ggml_sycl_pool_host : public ggml_sycl_pool {
+    queue_ptr qptr;
+    int device;
+
+    inline static int counter{ 0 };
+
+    struct ggml_sycl_buffer {
+        void * ptr = nullptr;
+        size_t size = 0;
+    };
+
+    // Set arbitrarly to 64
+    static constexpr int MAX_POOL_SIZE{ 64 };
+    std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE);
+    size_t pool_size = 0;
+
+    explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {}
+
+    ~ggml_sycl_pool_host() {
+        for (int i = 0; i < MAX_POOL_SIZE; ++i) {
+            ggml_sycl_buffer & b = buffer_pool[i];
+            if (b.ptr != nullptr) {
+                SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
+                b.ptr = nullptr;
+                pool_size -= b.size;
+                b.size = 0;
+            }
+        }
+        counter = 0;
+    }
+
+    void * alloc(size_t size, size_t * actual_size) override {
+        if (counter == MAX_POOL_SIZE) {
+            ggml_sycl_buffer b = buffer_pool[0];
+            void * ptr = b.ptr;
+            *actual_size = b.size;
+            counter = 1;
+            return ptr;
+        }
+        ggml_sycl_buffer & b = buffer_pool[counter];
+
+        if (b.ptr == nullptr) {
+            void * ptr;
+
+            SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr)));
+            if (!ptr) {
+                GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size);
+                return nullptr;
+            }
+            pool_size += size;
+            *actual_size = size;
+            counter = counter + 1;
+            return ptr;
+        } else {
+            ++counter;
+            b.size = size;
+            return b.ptr;
+        }
+    }
+
+    void free(void * ptr, size_t size) override {
+        // if the pool is not completed add the pointer to it in place of the first nullptr found.
+        // Otherwise do nothing, pointers will be freed once the pool is deallocated.
+        for (int i = 0; i < MAX_POOL_SIZE; ++i) {
+            ggml_sycl_buffer & b = buffer_pool[i];
+            if (b.ptr == nullptr) {
+                b.ptr = ptr;
+                b.size = size;
+                return;
+            }
+        }
+    }
+};
+
+std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) {
+    // return pool for the host to speed up memory management
+    return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_host(qptr, device));
+}
+
 std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
     // TBD: NO VMM support
     // if (ggml_sycl_info().devices[device].vmm) {
```
```diff
@@ -1190,9 +1362,6 @@ std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
     // struct ggml_sycl_pool_vmm : public ggml_sycl_pool
 
 /// kernels
-
-typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
-typedef void (*ggml_sycl_func_t)(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
 typedef void (*ggml_sycl_op_mul_mat_t)(
     ggml_backend_sycl_context & ctx,
     const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
@@ -1264,81 +1433,57 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy,
     reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
 }
 
-template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void k_get_rows(
-            const void * src0, const int32_t * src1, dst_t * dst,
-            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
-            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
-            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
-            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
-            size_t s10, size_t s11, size_t s12,
-            const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
-
-    const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                     item_ct1.get_local_id(2)) *
-                    2;
-    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) /
-                    ne12;
-    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) %
-                    ne12;
-
-    if (i00 >= ne00) {
-        return;
-    }
+template <int ElementsPerWI>
+static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor,
+                                                      const int kx, const int kx_padded, const sycl::nd_item<1> & it) {
+    /*
+        Quantizes and reorders the resultant q8 tensor in a per row fashion
+        Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
+    */
 
-    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+    auto subgroup_id = it.get_group(0);
+    auto wi_id = it.get_local_id(0);
 
-    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
+    const int num_blocks_per_row = kx / QK8_1;
+    auto row = subgroup_id / num_blocks_per_row;
+    auto col = subgroup_id % num_blocks_per_row;
 
-    const int ib = i00/qk; // block index
-    const int iqs = (i00%qk)/qr; // quant index
-    const int iybs = i00 - i00%qk; // dst block start index
-    const int y_offset = qr == 1 ? 1 : qk/2;
+    auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1);
+    auto col_offset = QK8_1 * col + wi_id * ElementsPerWI;
 
-    // dequantize
-    dfloat2 v;
-    dequantize_kernel(src0_row, ib, iqs, v);
+    auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
+    auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2));
 
-    dst_row[iybs + iqs + 0] = v.x();
-    dst_row[iybs + iqs + y_offset] = v.y();
-}
+    sycl::vec<float, ElementsPerWI> wi_f32_vals;
+    sycl::vec<int8_t, ElementsPerWI> quantized_values;
 
-template<typename src0_t, typename dst_t>
-static void k_get_rows_float(
-    const src0_t * src0, const int32_t * src1, dst_t * dst,
-    int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
-    /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
-    /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
-    /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
-    size_t s10, size_t s11, size_t s12,
-    const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
+    auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
+    wi_f32_vals = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset);
 
-    const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                    item_ct1.get_local_id(2);
-    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) /
-                    ne12;
-    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) %
-                    ne12;
-
-    if (i00 >= ne00) {
-        return;
+    float sum = 0.0f;
+    float amax = 0.0f;
+
+#pragma unroll(ElementsPerWI)
+    for (int i = 0; i < ElementsPerWI; i++) {
+        sum += wi_f32_vals[i];
+        amax = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i]));
+        quantized_values[i] = 0;
     }
+    sum = sycl::reduce_over_group(it.get_group(), sum, sycl::plus<float>());
+    amax = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum<float>());
+    float d = amax == 0 ? 1 : amax / 127;
 
-    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+#pragma unroll(ElementsPerWI)
+    for (int i = 0; i < ElementsPerWI; i++) {
+        quantized_values[i] = sycl::round(wi_f32_vals[i] / d);
+    }
 
-    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
+    d = amax == 0 ? 0 : d;
 
-    dst_row[i00] = src0_row[i00];
+    *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values;
+    if (wi_id == 0) {
+        *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum));
+    }
 }
 
 static void mul_mat_p021_f16_f32(
```
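The replacement kernel above computes, per `QK8_1`-wide block, the scale `d = amax / 127` and the block sum alongside the rounded int8 values. A CPU reference of just that math, assuming `QK8_1 == 32` (the function name below is illustrative):

```c++
// CPU reference for the per-block q8_1 math in quantize_and_reorder_q8_1:
// d = amax / 127 (1 while rounding if the block is all zero, stored as 0),
// q[i] = round(x[i] / d), plus the block sum kept for later dot products.
#include <algorithm>
#include <cmath>
#include <cstdint>

constexpr int QK8_1 = 32;

void quantize_block_q8_1_ref(const float * x, int8_t * q, float * d_out, float * sum_out) {
    float sum = 0.0f, amax = 0.0f;
    for (int i = 0; i < QK8_1; ++i) {
        sum += x[i];
        amax = std::max(amax, std::fabs(x[i]));
    }
    const float d = amax == 0.0f ? 1.0f : amax / 127.0f;  // avoid dividing by zero
    for (int i = 0; i < QK8_1; ++i) {
        q[i] = (int8_t) std::lround(x[i] / d);
    }
    *d_out   = amax == 0.0f ? 0.0f : d;  // an all-zero block stores d = 0
    *sum_out = sum;
}
```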
```diff
@@ -1451,193 +1596,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     }
 }
 
-static void cpy_1_f32_f32(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    float * dsti = (float *) cdsti;
-
-    *dsti = *xi;
-}
-
-static void cpy_1_f32_f16(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    sycl::half *dsti = (sycl::half *)cdsti;
-
-    *dsti = sycl::vec<float, 1>(*xi)
-                .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
-}
-
-static void cpy_1_f16_f16(const char * cxi, char * cdsti) {
-    const sycl::half *xi = (const sycl::half *)cxi;
-    sycl::half *dsti = (sycl::half *)cdsti;
-
-    *dsti = *xi;
-}
-
-static void cpy_1_f16_f32(const char * cxi, char * cdsti) {
-    const sycl::half *xi = (const sycl::half *)cxi;
-    float * dsti = (float *) cdsti;
-
-    *dsti = *xi;
-}
-
-static void cpy_1_i16_i16(const char * cxi, char * cdsti) {
-    const int16_t *xi = (const int16_t *)cxi;
-    int16_t *dsti = (int16_t *)cdsti;
-
-    *dsti = *xi;
-}
-
-static void cpy_1_i32_i32(const char * cxi, char * cdsti) {
-    const int32_t *xi = (const int32_t *)cxi;
-    int32_t *dsti = (int32_t *)cdsti;
-
-    *dsti = *xi;
-}
-
-template <cpy_kernel_t cpy_1>
-static void cpy_f32_f16(const char * cx, char * cdst, const int ne,
-                        const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                        const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                        const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= ne) {
-        return;
-    }
-
-    // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
-    // then combine those indices with the corresponding byte offsets to get the total offsets
-    const int i03 = i/(ne00 * ne01 * ne02);
-    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
-    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
-    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
-
-    const int i13 = i/(ne10 * ne11 * ne12);
-    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
-    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
-    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
-
-    cpy_1(cx + x_offset, cdst + dst_offset);
-}
-
-static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    block_q8_0 * dsti = (block_q8_0 *) cdsti;
-
-    float amax = 0.0f; // absolute max
-
-    for (int j = 0; j < QK8_0; j++) {
-        const float v = xi[j];
-        amax = sycl::fmax(amax, sycl::fabs((float)v));
-    }
-
-    const float d = amax / ((1 << 7) - 1);
-    const float id = d ? 1.0f/d : 0.0f;
-
-    dsti->d = d;
-
-    for (int j = 0; j < QK8_0; ++j) {
-        const float x0 = xi[j]*id;
-
-        dsti->qs[j] = sycl::round((float)x0);
-    }
-}
-
-static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    block_q4_0 * dsti = (block_q4_0 *) cdsti;
-
-    float amax = 0.0f;
-    float vmax = 0.0f;
-
-    for (int j = 0; j < QK4_0; ++j) {
-        const float v = xi[j];
-        if (amax < sycl::fabs((float)v)) {
-            amax = sycl::fabs((float)v);
-            vmax = v;
-        }
-    }
-
-    const float d = vmax / -8;
-    const float id = d ? 1.0f/d : 0.0f;
-
-    dsti->d = d;
-
-    for (int j = 0; j < QK4_0/2; ++j) {
-        const float x0 = xi[0 + j]*id;
-        const float x1 = xi[QK4_0/2 + j]*id;
-
-        const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 8.5f));
-        const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 8.5f));
-
-        dsti->qs[j] = xi0;
-        dsti->qs[j] |= xi1 << 4;
-    }
-}
-
-static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    block_q4_1 * dsti = (block_q4_1 *) cdsti;
-
-    float vmin = FLT_MAX;
-    float vmax = -FLT_MAX;
-
-    for (int j = 0; j < QK4_1; ++j) {
-        const float v = xi[j];
-
-        if (v < vmin) vmin = v;
-        if (v > vmax) vmax = v;
-    }
-
-    const float d = (vmax - vmin) / ((1 << 4) - 1);
-    const float id = d ? 1.0f/d : 0.0f;
-
-    dsti->dm.x() = d;
-    dsti->dm.y() = vmin;
-
-    for (int j = 0; j < QK4_1/2; ++j) {
-        const float x0 = (xi[0 + j] - vmin)*id;
-        const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
-
-        const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 0.5f));
-        const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 0.5f));
-
-        dsti->qs[j] = xi0;
-        dsti->qs[j] |= xi1 << 4;
-    }
-}
-
-template <cpy_kernel_t cpy_blck, int qk>
-static void cpy_f32_q(const char * cx, char * cdst, const int ne,
-                      const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                      const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                      const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
-    const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                   item_ct1.get_local_id(2)) *
-                  qk;
-
-    if (i >= ne) {
-        return;
-    }
-
-    const int i03 = i/(ne00 * ne01 * ne02);
-    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
-    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
-    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
-
-    const int i13 = i/(ne10 * ne11 * ne12);
-    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
-    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
-    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
-
-    cpy_blck(cx + x_offset, cdst + dst_offset);
-}
-
 static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
                            const sycl::nd_item<3> &item_ct1) {
     const int row = item_ct1.get_group(1);
@@ -1749,17 +1707,6 @@ static void scale_f32(const float * x, float * dst, const float scale, const int
     dst[i] = scale * x[i];
 }
 
-static void clamp_f32(const float * x, float * dst, const float min, const float max, const int k,
-                      const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-
-    dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
-}
 
 template <typename Ti, typename To>
 static void pool2d_nchw_kernel(
@@ -1823,98 +1770,30 @@ static void pool2d_nchw_kernel(
                 o_ptr[cur_oh * ow + cur_ow] = res;
 }
 
-template <int qk, int qr, dequantize_kernel_t dq>
-static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                          const ggml_tensor *src1, ggml_tensor *dst,
-                          const void *src0_dd, const int32_t *src1_dd,
-                          float *dst_dd, queue_ptr stream) {
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
-    const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
-    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
-
-    // strides in elements
-    //const size_t s0 = nb0 / ggml_element_size(dst);
-    const size_t s1 = nb1 / ggml_element_size(dst);
-    const size_t s2 = nb2 / ggml_element_size(dst);
-    const size_t s3 = nb3 / ggml_element_size(dst);
-
-    const size_t s10 = nb10 / ggml_element_size(src1);
-    const size_t s11 = nb11 / ggml_element_size(src1);
-    const size_t s12 = nb12 / ggml_element_size(src1);
-    //const size_t s13 = nb13 / ggml_element_size(src1);
-
-    GGML_ASSERT(ne00 % 2 == 0);
-
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             k_get_rows<qk, qr, dq>(
-                                 src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
-                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
-                         });
-
-    GGML_UNUSED(dst);
-    GGML_UNUSED(ctx);
-}
-
-template <typename src0_t>
-static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                const ggml_tensor *src1, ggml_tensor *dst,
-                                const src0_t *src0_dd, const int32_t *src1_dd,
-                                float *dst_dd, queue_ptr stream) {
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
-    const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE;
-    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
-
-    // strides in elements
-    //const size_t s0 = nb0 / ggml_element_size(dst);
-    const size_t s1 = nb1 / ggml_element_size(dst);
-    const size_t s2 = nb2 / ggml_element_size(dst);
-    const size_t s3 = nb3 / ggml_element_size(dst);
-
-    const size_t s10 = nb10 / ggml_element_size(src1);
-    const size_t s11 = nb11 / ggml_element_size(src1);
-    const size_t s12 = nb12 / ggml_element_size(src1);
-    //const size_t s13 = nb13 / ggml_element_size(src1);
-
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
-                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
-            });
-    }
-
-    GGML_UNUSED(dst);
-    GGML_UNUSED(ctx);
-}
-
-static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
-                                   const int ky, const int kx_padded,
-                                   queue_ptr stream) {
-    const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
-    const sycl::range<3> num_blocks(1, ky, block_num_x);
-    int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
-    static_assert(QK8_1 % WARP_SIZE == 0);
-    const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
+static void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded,
+                                   bool reorder_q8_tensor, queue_ptr stream) {
+    if (reorder_q8_tensor) {
+        auto local_range = std::size_t(WARP_SIZE);
+        auto num_quant_blocks = ky * (kx / QK8_1);
+        auto global_range = num_quant_blocks * local_range;
+        stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }),
+                             [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                 quantize_and_reorder_q8_1<QK8_1 / WARP_SIZE>(x, vy, kx, kx_padded, it);
+                             });
+    } else {
+        const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
+        const sycl::range<3> num_blocks(1, ky, block_num_x);
+        int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
+        static_assert(QK8_1 % WARP_SIZE == 0);
+        const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
+        {
+            dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
 
-        stream->parallel_for(
-            sycl::nd_range<3>(num_blocks * block_size, block_size),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
-                quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
-            });
+            stream->parallel_for(sycl::nd_range<3>(num_blocks * block_size, block_size),
+                                 [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                     quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
+                                 });
+        }
     }
 }
 
```
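In the reorder branch above, one sub-group of `WARP_SIZE` work-items handles one quant block, so the ND-range is just blocks × `WARP_SIZE`. A worked example of that arithmetic, assuming `WARP_SIZE == 32` and `QK8_1 == 32` (the `kx`/`ky` values are made up for illustration):

```c++
// Launch-geometry arithmetic for the reorder path.
#include <cstdio>

int main() {
    const int    WARP_SIZE = 32, QK8_1 = 32;
    const int    kx = 4096, ky = 2;                     // row length, row count
    const int    num_quant_blocks = ky * (kx / QK8_1);  // 2 * 128 = 256 blocks
    const size_t local_range  = WARP_SIZE;              // one sub-group per block
    const size_t global_range = (size_t) num_quant_blocks * local_range;  // 8192 work-items
    std::printf("blocks=%d global=%zu local=%zu\n",
                num_quant_blocks, global_range, local_range);
    return 0;
}
```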
@@ -1933,7 +1812,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
|
|
1933
1812
|
|
1934
1813
|
stream->parallel_for(
|
1935
1814
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
1936
|
-
[=](sycl::nd_item<3> item_ct1) [[
|
1815
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
1937
1816
|
mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x,
|
1938
1817
|
nchannels_y, item_ct1);
|
1939
1818
|
});
|
@@ -1953,7 +1832,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
|
|
1953
1832
|
|
1954
1833
|
stream->parallel_for(
|
1955
1834
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
1956
|
-
[=](sycl::nd_item<3> item_ct1) [[
|
1835
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
1957
1836
|
mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
|
1958
1837
|
row_stride_x, channel_stride_x,
|
1959
1838
|
nchannels_y / nchannels_x, item_ct1);
|
@@ -1961,231 +1840,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
|
|
1961
1840
|
}
|
1962
1841
|
}
|
1963
1842
|
|
1964
|
-
static void
|
1965
|
-
ggml_cpy_f16_f32_sycl(const char *cx, char *cdst, const int ne, const int ne00,
|
1966
|
-
const int ne01, const int ne02, const int nb00,
|
1967
|
-
const int nb01, const int nb02, const int nb03,
|
1968
|
-
const int ne10, const int ne11, const int ne12,
|
1969
|
-
const int nb10, const int nb11, const int nb12,
|
1970
|
-
const int nb13, queue_ptr stream) {
|
1971
|
-
|
1972
|
-
const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
|
1973
|
-
{
|
1974
|
-
dpct::has_capability_or_fail(stream->get_device(),
|
1975
|
-
{sycl::aspect::fp16});
|
1976
1843
|
|
1977
|
-
stream->parallel_for(
|
1978
|
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
1979
|
-
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
1980
|
-
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
1981
|
-
[=](sycl::nd_item<3> item_ct1) {
|
1982
|
-
cpy_f32_f16<cpy_1_f16_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00,
|
1983
|
-
nb01, nb02, nb03, ne10, ne11, ne12,
|
1984
|
-
nb10, nb11, nb12, nb13, item_ct1);
|
1985
|
-
});
|
1986
|
-
}
|
1987
|
-
}
|
1988
|
-
|
1989
|
-
static void ggml_cpy_f32_f32_sycl(const char *cx, char *cdst, const int ne,
|
1990
|
-
const int ne00, const int ne01,
|
1991
|
-
const int ne02, const int nb00,
|
1992
|
-
const int nb01, const int nb02,
|
1993
|
-
const int nb03, const int ne10,
|
1994
|
-
const int ne11, const int ne12,
|
1995
|
-
const int nb10, const int nb11,
|
1996
|
-
const int nb12, const int nb13,
|
1997
|
-
queue_ptr stream) {
|
1998
|
-
|
1999
|
-
const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
|
2000
|
-
{
|
2001
|
-
dpct::has_capability_or_fail(stream->get_device(),
|
2002
|
-
{sycl::aspect::fp16});
|
2003
|
-
|
2004
|
-
stream->parallel_for(
|
2005
|
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
2006
|
-
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
2007
|
-
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
2008
|
-
[=](sycl::nd_item<3> item_ct1) {
|
2009
|
-
cpy_f32_f16<cpy_1_f32_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
|
2010
|
-
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
|
2011
|
-
item_ct1);
|
2012
|
-
});
|
2013
|
-
}
|
2014
|
-
}
|
2015
|
-
|
2016
|
-
static void ggml_cpy_f32_f16_sycl(const char *cx, char *cdst, const int ne,
|
2017
|
-
const int ne00, const int ne01,
|
2018
|
-
const int ne02, const int nb00,
|
2019
|
-
const int nb01, const int nb02,
|
2020
|
-
const int nb03, const int ne10,
|
2021
|
-
const int ne11, const int ne12,
|
2022
|
-
const int nb10, const int nb11,
|
2023
|
-
const int nb12, const int nb13,
|
2024
|
-
queue_ptr stream) {
|
2025
|
-
|
2026
|
-
const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
|
2027
|
-
{
|
2028
|
-
dpct::has_capability_or_fail(stream->get_device(),
|
2029
|
-
{sycl::aspect::fp16});
|
2030
|
-
|
2031
|
-
stream->parallel_for(
|
2032
|
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
2033
|
-
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
2034
|
-
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
2035
|
-
[=](sycl::nd_item<3> item_ct1) {
|
2036
|
-
cpy_f32_f16<cpy_1_f32_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
|
2037
|
-
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
|
2038
|
-
item_ct1);
|
2039
|
-
});
|
2040
|
-
}
|
2041
|
-
}
|
2042
|
-
|
2043
|
-
static void ggml_cpy_f32_q8_0_sycl(const char *cx, char *cdst, const int ne,
|
2044
|
-
const int ne00, const int ne01,
|
2045
|
-
const int ne02, const int nb00,
|
2046
|
-
const int nb01, const int nb02,
|
2047
|
-
const int nb03, const int ne10,
|
2048
|
-
const int ne11, const int ne12,
|
2049
|
-
const int nb10, const int nb11,
|
2050
|
-
const int nb12, const int nb13,
|
2051
|
-
queue_ptr stream) {
|
2052
|
-
|
2053
|
-
GGML_ASSERT(ne % QK8_0 == 0);
|
2054
|
-
const int num_blocks = ne / QK8_0;
|
2055
|
-
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
|
2056
|
-
sycl::range<3>(1, 1, 1)),
|
2057
|
-
[=](sycl::nd_item<3> item_ct1) {
|
2058
|
-
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(
|
2059
|
-
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
|
2060
|
-
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
|
2061
|
-
item_ct1);
|
2062
|
-
});
|
2063
|
-
}
|
2064
|
-
|
2065
|
-
static void ggml_cpy_f32_q4_0_sycl(const char *cx, char *cdst, const int ne,
|
2066
|
-
const int ne00, const int ne01,
|
2067
|
-
const int ne02, const int nb00,
|
2068
|
-
const int nb01, const int nb02,
|
2069
|
-
const int nb03, const int ne10,
|
2070
|
-
const int ne11, const int ne12,
|
2071
|
-
const int nb10, const int nb11,
|
2072
|
-
const int nb12, const int nb13,
|
2073
|
-
queue_ptr stream) {
|
2074
|
-
|
2075
|
-
GGML_ASSERT(ne % QK4_0 == 0);
|
2076
|
-
const int num_blocks = ne / QK4_0;
|
2077
|
-
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
|
2078
|
-
sycl::range<3>(1, 1, 1)),
|
2079
|
-
[=](sycl::nd_item<3> item_ct1) {
|
2080
|
-
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(
|
2081
|
-
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
|
2082
|
-
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
|
2083
|
-
item_ct1);
|
2084
|
-
});
|
2085
|
-
}
|
2086
|
-
|
2087
|
-
static void ggml_cpy_f32_q4_1_sycl(const char *cx, char *cdst, const int ne,
|
2088
|
-
const int ne00, const int ne01,
|
2089
|
-
const int ne02, const int nb00,
|
2090
|
-
const int nb01, const int nb02,
|
2091
|
-
-                                  const int nb03, const int ne10,
-                                  const int ne11, const int ne12,
-                                  const int nb10, const int nb11,
-                                  const int nb12, const int nb13,
-                                  queue_ptr stream) {
-
-    GGML_ASSERT(ne % QK4_1 == 0);
-    const int num_blocks = ne / QK4_1;
-    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
-                                           sycl::range<3>(1, 1, 1)),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(
-                                 cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                 nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                 item_ct1);
-                         });
-}
-
-static void ggml_cpy_f16_f16_sycl(const char *cx, char *cdst, const int ne,
-                                  const int ne00, const int ne01,
-                                  const int ne02, const int nb00,
-                                  const int nb01, const int nb02,
-                                  const int nb03, const int ne10,
-                                  const int ne11, const int ne12,
-                                  const int nb10, const int nb11,
-                                  const int nb12, const int nb13,
-                                  queue_ptr stream) {
-
-    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                cpy_f32_f16<cpy_1_f16_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                           item_ct1);
-            });
-    }
-}
-
-static void ggml_cpy_i16_i16_sycl(const char *cx, char *cdst, const int ne,
-                                  const int ne00, const int ne01,
-                                  const int ne02, const int nb00,
-                                  const int nb01, const int nb02,
-                                  const int nb03, const int ne10,
-                                  const int ne11, const int ne12,
-                                  const int nb10, const int nb11,
-                                  const int nb12, const int nb13,
-                                  queue_ptr stream) {
-
-    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
-    {
-        // dpct::has_capability_or_fail(stream->get_device(),
-        //                              {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                cpy_f32_f16<cpy_1_i16_i16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                           item_ct1);
-            });
-    }
-}
-
-static void ggml_cpy_i32_i32_sycl(const char *cx, char *cdst, const int ne,
-                                  const int ne00, const int ne01,
-                                  const int ne02, const int nb00,
-                                  const int nb01, const int nb02,
-                                  const int nb03, const int ne10,
-                                  const int ne11, const int ne12,
-                                  const int nb10, const int nb11,
-                                  const int nb12, const int nb13,
-                                  queue_ptr stream) {
-
-    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
-    {
-        // dpct::has_capability_or_fail(stream->get_device(),
-        //                              {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                cpy_f32_f16<cpy_1_i32_i32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                           item_ct1);
-            });
-    }
-}
 
 static void scale_f32_sycl(const float *x, float *dst, const float scale,
                            const int k, queue_ptr stream) {
@@ -2199,18 +1854,6 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale,
         });
 }
 
-static void clamp_f32_sycl(const float *x, float *dst, const float min,
-                           const float max, const int k,
-                           queue_ptr stream) {
-    const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            clamp_f32(x, dst, min, max, k, item_ct1);
-        });
-}
 
 static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
                               const int nrows, queue_ptr stream) {
@@ -2218,7 +1861,7 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
     const sycl::range<3> block_nums(1, nrows, 1);
     stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
                          [=](sycl::nd_item<3> item_ct1)
-                             [[
+                             [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                                  k_sum_rows_f32(x, dst, ncols, item_ct1);
                              });
 }
@@ -2242,13 +1885,12 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
     const size_t shared_mem = ncols_pad * sizeof(int);
 
     if (order == GGML_SORT_ORDER_ASC) {
-        stream
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
                 sycl::range<1>(shared_mem), cgh);
 
-
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
+            sycl_parallel_for(
+                cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
                     k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
                         x, dst, ncols, ncols_pad, item_ct1,
                         dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
@@ -2256,13 +1898,12 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
                 });
         });
     } else if (order == GGML_SORT_ORDER_DESC) {
-        stream
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
                 sycl::range<1>(shared_mem), cgh);
 
-
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
+            sycl_parallel_for(
+                cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
                     k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
                         x, dst, ncols, ncols_pad, item_ct1,
                         dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
@@ -2280,50 +1921,47 @@ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
     const sycl::range<3> block_nums(1, nrows, 1);
     const size_t shared_mem = 256 * sizeof(float);
 
-    stream
+    sycl_launch(stream, [&](sycl::handler & cgh) {
         sycl::local_accessor<float, 1> shared_data(
             sycl::range<1>(shared_mem/sizeof(float)), cgh);
         sycl::local_accessor<int, 1> shared_indices(
             sycl::range<1>(shared_mem/sizeof(float)), cgh);
 
-        cgh
-
-
-            const int tid = item_ct1.get_local_id(2);
-            const int row = item_ct1.get_global_id(1);
-
-            float max_val = -INFINITY;
-            int max_idx = -1;
-
-            for (int col = tid; col < ncols; col += 256) {
-                float val = x[row * ncols + col];
-                if (val > max_val) {
-                    max_val = val;
-                    max_idx = col;
-                }
-            }
+        sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            const int tid = item_ct1.get_local_id(2);
+            const int row = item_ct1.get_global_id(1);
 
-
-
-            item_ct1.barrier(sycl::access::fence_space::local_space);
+            float max_val = -INFINITY;
+            int max_idx = -1;
 
-
-
-
-
-
-                    shared_data[tid] = val2;
-                    shared_indices[tid] = shared_indices[tid + stride];
-                }
-            }
-            item_ct1.barrier(sycl::access::fence_space::local_space);
+            for (int col = tid; col < ncols; col += 256) {
+                float val = x[row * ncols + col];
+                if (val > max_val) {
+                    max_val = val;
+                    max_idx = col;
                 }
+            }
 
+            shared_data[tid] = max_val;
+            shared_indices[tid] = max_idx;
+            item_ct1.barrier(sycl::access::fence_space::local_space);
 
-
+            for (int stride = 256 / 2; stride > 0; stride >>= 1) {
+                if (tid < stride) {
+                    float val1 = shared_data[tid];
+                    float val2 = shared_data[tid + stride];
+                    if (val2 > val1) {
+                        shared_data[tid] = val2;
+                        shared_indices[tid] = shared_indices[tid + stride];
+                    }
                 }
-
+                item_ct1.barrier(sycl::access::fence_space::local_space);
+            }
+
+            if (tid == 0) {
+                dst[row] = shared_indices[0];
+            }
+        });
     });
 }
 static void diag_mask_inf_f32_sycl(const float *x, float *dst,
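Note: the rewritten argmax kernel above uses a standard two-phase shared-memory reduction: each work item first scans a strided slice of its row, then a tree reduction halves the number of active slots until slot 0 holds the row's argmax. Below is a minimal CPU-side sketch of the same pattern in plain C++ (a simplified model for illustration, not the backend's API; the 256-slot scratch arrays mirror the kernel's shared_data/shared_indices, and the barrier-separated device phases become plain loops):

#include <cstdio>
#include <cmath>
#include <vector>

// Simplified CPU model of the kernel's two-phase argmax:
// phase 1: each of the 256 "work items" scans a strided slice of the row;
// phase 2: a tree reduction halves the active range until slot 0 wins.
static int row_argmax(const float * x, int ncols) {
    const int nthreads = 256; // matches the kernel's fixed local size
    std::vector<float> shared_data(nthreads, -INFINITY);
    std::vector<int>   shared_indices(nthreads, -1);

    for (int tid = 0; tid < nthreads; ++tid) {        // phase 1: strided scan
        for (int col = tid; col < ncols; col += nthreads) {
            if (x[col] > shared_data[tid]) {
                shared_data[tid]    = x[col];
                shared_indices[tid] = col;
            }
        }
    }
    for (int stride = nthreads / 2; stride > 0; stride >>= 1) { // phase 2
        for (int tid = 0; tid < stride; ++tid) {  // barrier-separated on device
            if (shared_data[tid + stride] > shared_data[tid]) {
                shared_data[tid]    = shared_data[tid + stride];
                shared_indices[tid] = shared_indices[tid + stride];
            }
        }
    }
    return shared_indices[0];
}

int main() {
    std::vector<float> row = {0.1f, 3.5f, -2.0f, 3.4f};
    printf("argmax = %d\n", row_argmax(row.data(), (int) row.size())); // prints 1
    return 0;
}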
@@ -2349,12 +1987,22 @@ static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
 
     dpct::memcpy_direction kind;
     char * src_ptr;
-    if (src->
+    if (ggml_backend_buffer_is_host(src->buffer)) {
         kind = dpct::host_to_device;
+        //GGML_SYCL_DEBUG("%s: Host buffer type src tensor\n", __func__);
         src_ptr = (char *) src->data;
         // GGML_SYCL_DEBUG("ggml_sycl_cpy_tensor_2d GGML_BACKEND_TYPE_CPU src_ptr %p\n", src_ptr);
-    } else if (src->
-
+    } else if (ggml_backend_buffer_is_sycl(src->buffer)) {
+        // If buffer is a SYCL buffer
+        //GGML_SYCL_DEBUG("%s: SYCL buffer type src tensor\n", __func__);
+        kind    = dpct::device_to_device;
+        src_ptr = (char *) src->data;
+    } else if (ggml_backend_buffer_is_sycl_split(src->buffer)) {
+        /*
+        If buffer is a SYCL split buffer
+        */
+        //GGML_SYCL_DEBUG("%s: Split buffer type src tensor\n", __func__);
+        GGML_ASSERT(i1_low == 0 && i1_high == src->ne[1]);
         kind = dpct::device_to_device;
         ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
         int id;
@@ -2411,65 +2059,6 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                  const ggml_tensor *src1, ggml_tensor *dst,
-                                  const float *src0_d, const float *src1_d,
-                                  float *dst_d, const queue_ptr &stream) {
-
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
-    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
-
-    const int32_t * src1_i32 = (const int32_t *) src1_d;
-
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
-                                src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_F32:
-            get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q4_0:
-            get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q4_1:
-            get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q5_0:
-            get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q5_1:
-            get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q8_0:
-            get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        default:
-            // TODO: k-quants
-            GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
-            GGML_ABORT("fatal error");
-            break;
-    }
-}
-
-
-static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                const ggml_tensor *src1, ggml_tensor *dst,
-                                const float *src0_d, const float *src1_d,
-                                float *dst_d,
-                                const queue_ptr &main_stream) {
-
-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
-
-    GGML_UNUSED(src1);
-    GGML_UNUSED(src1_d);
-}
-
-
 inline void ggml_sycl_op_mul_mat_sycl(
     ggml_backend_sycl_context & ctx,
     const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
@@ -2484,33 +2073,31 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne10 = src1->ne[0];
-
+    GGML_ASSERT(ne00 == ne10);
 
     const int64_t row_diff = row_high - row_low;
 
     int id;
     SYCL_CHECK(
         CHECK_TRY_ERROR(id = get_current_device_id()));
-
-    const int64_t ne0 = dst->ne[0];
+
+    const int64_t ne0 = dst->ne[0]; // used by MKL only
     // the main device has a larger memory buffer to hold the results from all GPUs
     // ldc == nrows of the matrix that cuBLAS writes into
-    int ldc = id == ctx.device ? ne0 : row_diff;
-#endif
+    int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only
 
 #ifdef GGML_SYCL_F16
     bool use_fp16 = true; // TODO(Yu) SYCL capability check
 #else
     bool use_fp16 = false;
 #endif
-    if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
-
-        dst->op_params[0] == GGML_PREC_DEFAULT) {
-
-        // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp16 path\n");
+    if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && use_fp16 && ggml_is_contiguous(src0) &&
+        row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
         ggml_sycl_pool_alloc<sycl::half> src0_as_f16(ctx.pool());
         if (src0->type != GGML_TYPE_F16) {
-
+            scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2,
+                                                 " : converting src0 to fp16");
+            const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type, dst);
             GGML_ASSERT(to_fp16_sycl != nullptr);
             size_t ne = row_diff*ne00;
             src0_as_f16.alloc(ne);
@@ -2522,7 +2109,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
         ggml_sycl_pool_alloc<sycl::half> src1_as_f16(ctx.pool());
         if (src1->type != GGML_TYPE_F16) {
-
+            scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2,
+                                                 " : converting src1 to fp16");
+            const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
             GGML_ASSERT(to_fp16_sycl != nullptr);
             size_t ne = src1_ncols*ne10;
             src1_as_f16.alloc(ne);
@@ -2531,40 +2120,47 @@ inline void ggml_sycl_op_mul_mat_sycl(
         const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16
                                          ? (const sycl::half *)src1->data + src1_padded_row_size
                                          : src1_as_f16.get();
-
-
-
-
-
-
-
-
-            &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
-            src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
-            dst_f16.get(), dpct::library_data_t::real_half, ldc,
-            dpct::library_data_t::real_half)));
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
-        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-#else
-        auto dnnl_stream = ctx.stream_dnnl(stream);
-        DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
-            src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
-        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+
+#if GGML_SYCL_DNNL
+        if (!g_ggml_sycl_disable_dnn) {
+            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
+                                      DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+                                      dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+        }
+        else
 #endif
-
-
-
+        {
+            ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
+
+            const sycl::half alpha_f16 = 1.0f;
+            const sycl::half beta_f16 = 0.0f;
+            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
+                *stream, oneapi::math::transpose::trans,
+                oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
+                &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
+                src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
+                dst_f16.get(), dpct::library_data_t::real_half, ldc,
+                dpct::library_data_t::real_half)));
+            scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
+                                                 " : converting dst to fp32");
+            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
+            to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+        }
+    } else {
         ggml_sycl_pool_alloc<float> src0_ddq_as_f32(ctx.pool());
         ggml_sycl_pool_alloc<float> src1_ddq_as_f32(ctx.pool());
         if (src0->type != GGML_TYPE_F32) {
-
+            scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
+                                                 " : converting src0 to fp32");
+            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type, dst);
             GGML_ASSERT(to_fp32_sycl != nullptr);
             src0_ddq_as_f32.alloc(row_diff*ne00);
             to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
         }
         if (src1->type != GGML_TYPE_F32) {
-
+            scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
+                                                 " : converting src1 to fp32");
+            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type, dst);
            GGML_ASSERT(to_fp32_sycl != nullptr);
            src1_ddq_as_f32.alloc(src1_ncols*ne10);
            to_fp32_sycl(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
@@ -2572,25 +2168,22 @@ inline void ggml_sycl_op_mul_mat_sycl(
     const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
     const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
 
-#if
-
-
-
-
-
-
-            ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
-#  else
-        SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
-            *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
-            dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
-            dst_dd_i, ldc)));
-#  endif
-#else
-        auto dnnl_stream = ctx.stream_dnnl(stream);
-        DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
-            src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
+#if GGML_SYCL_DNNL
+        if (!g_ggml_sycl_disable_dnn) {
+            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
+                                      DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+                                      dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+        }
+        else
 #endif
+        {
+            const float alpha = 1.0f;
+            const float beta = 0.0f;
+            SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
+                get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
+                src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
+                dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
+        }
     }
     GGML_UNUSED(dst);
     GGML_UNUSED(src1_ddq_i);
@@ -2602,13 +2195,13 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx,
-
-                                const float *src0_dd, const float *src1_dd,
-                                float *dst_dd, const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    float * dst_dd = static_cast<float *>(dst->data);
 
     const int32_t * opts = (const int32_t *)dst->op_params;
     enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
@@ -2619,8 +2212,8 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
     const int p0 = opts[5];
     const int p1 = opts[6];
 
-    const int64_t IH =
-    const int64_t IW =
+    const int64_t IH = dst->src[0]->ne[1];
+    const int64_t IW = dst->src[0]->ne[0];
 
     const int64_t N = dst->ne[3];
     const int64_t OC = dst->ne[2];
@@ -2639,163 +2232,101 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
                                  parallel_elements, src0_dd, dst_dd, op,
                                  item_ct1);
                          });
-
-    GGML_UNUSED(src1);
-    GGML_UNUSED(src1_dd);
-    GGML_UNUSED(ctx);
 }
 
-inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx,
-
-                             const float *src0_dd, const float *src1_dd,
-                             float *dst_dd,
-                             const queue_ptr &main_stream) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    float * dst_dd = static_cast<float *>(dst->data);
 
-    const int64_t ne = ggml_nelements(
+    const int64_t ne = ggml_nelements(dst->src[0]);
 
     sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream);
-
-    GGML_UNUSED(src1);
-    GGML_UNUSED(dst);
-    GGML_UNUSED(src1_dd);
-    GGML_UNUSED(ctx);
 }
 
-inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx,
-
-                                  const float *src0_dd, const float *src1_dd,
-                                  float *dst_dd,
-                                  const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    float * dst_dd = static_cast<float *>(dst->data);
 
-    const int64_t ncols =
-    const int64_t nrows = ggml_nrows(
+    const int64_t ncols = dst->src[0]->ne[0];
+    const int64_t nrows = ggml_nrows(dst->src[0]);
 
     sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
-
-    GGML_UNUSED(src1);
-    GGML_UNUSED(dst);
-    GGML_UNUSED(src1_dd);
-    GGML_UNUSED(ctx);
 }
 
-inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx,
-
-
-
-
+inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_I32);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    int32_t * dst_dd = static_cast<int32_t *>(dst->data);
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_I32);
 
-    const int64_t ncols =
-    const int64_t nrows = ggml_nrows(
+    const int64_t ncols = dst->src[0]->ne[0];
+    const int64_t nrows = ggml_nrows(dst->src[0]);
 
     enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
 
-    argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream);
-
-    GGML_UNUSED(src1);
-    GGML_UNUSED(dst);
-    GGML_UNUSED(src1_dd);
-    GGML_UNUSED(ctx);
+    argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream);
 }
 
-inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx,
-
-                                const float *src0_dd, const float *src1_dd,
-                                float *dst_dd,
-                                const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_I32);
 
-
-
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    int32_t * dst_dd = static_cast<int32_t *>(dst->data);
 
-
+    const int64_t ncols = dst->src[0]->ne[0];
+    const int64_t nrows = ggml_nrows(dst->src[0]);
 
-
-    GGML_UNUSED(dst);
-    GGML_UNUSED(src1_dd);
-    GGML_UNUSED(ctx);
+    argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
 }
 
-inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx,
-
-                                       ggml_tensor *dst, const float *src0_dd,
-                                       const float *src1_dd, float *dst_dd,
-                                       const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    float * dst_dd = static_cast<float *>(dst->data);
 
-    const int64_t ne00 =
-    const int64_t ne01 =
-    const int nrows0 = ggml_nrows(
+    const int64_t ne00 = dst->src[0]->ne[0];
+    const int64_t ne01 = dst->src[0]->ne[1];
+    const int nrows0 = ggml_nrows(dst->src[0]);
 
     const int n_past = ((int32_t *) dst->op_params)[0];
 
     diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
-
-    GGML_UNUSED(src1);
-    GGML_UNUSED(dst);
-    GGML_UNUSED(src1_dd);
-    GGML_UNUSED(ctx);
 }
 
-inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx,
-
-                               const float *src1_dd, float *dst_dd,
-                               const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    float * dst_dd = static_cast<float *>(dst->data);
 
     float scale;
     memcpy(&scale, dst->op_params, sizeof(float));
 
-    scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(
+    scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
     /*
     DPCT1010:87: SYCL uses exceptions to report errors and does not use the
     error codes. The call was replaced with 0. You need to rewrite this code.
     */
     SYCL_CHECK(0);
-
-    GGML_UNUSED(src1);
-    GGML_UNUSED(dst);
-    GGML_UNUSED(src1_dd);
-    GGML_UNUSED(ctx);
-}
-
-inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                               ggml_tensor *dst, const float *src0_dd,
-                               const float *src1_dd, float *dst_dd,
-                               const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    float min;
-    float max;
-    memcpy(&min, dst->op_params, sizeof(float));
-    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
-
-    clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
-    /*
-    DPCT1010:88: SYCL uses exceptions to report errors and does not use the
-    error codes. The call was replaced with 0. You need to rewrite this code.
-    */
-    SYCL_CHECK(0);
-
-    GGML_UNUSED(src1);
-    GGML_UNUSED(dst);
-    GGML_UNUSED(src1_dd);
-    GGML_UNUSED(ctx);
 }
 
 static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
@@ -2857,8 +2388,8 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
     const int nb2 = dst->nb[2];
     const int nb3 = dst->nb[3];
 
-    GGML_ASSERT(dst->
-    GGML_ASSERT(src1->
+    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer));
+    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src1->buffer));
     GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
 
     GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
@@ -2878,7 +2409,7 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
 
     int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
 
-    const bool split = src0->
+    const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
     GGML_ASSERT(!(split && ne02 > 1));
     GGML_ASSERT(!(split && ne03 > 1));
     GGML_ASSERT(!(split && ne02 < ne12));
@@ -2966,7 +2497,10 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
             dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
 
             if (src1_on_device && src1_is_contiguous) {
-
+                bool reorder_q8_tensor = src0->extra && ((ggml_tensor_extra_gpu *)src0->extra)->optimized_feature.reorder;
+                scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
+                                                     /*num_src=*/2, " : converting src1 to Q8_1");
+                quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, reorder_q8_tensor, stream);
                 /*
                 DPCT1010:90: SYCL uses exceptions to report errors and does not
                 use the error codes. The call was replaced with 0. You need to
@@ -3002,7 +2536,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
     for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
         const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0;
         const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
-
         for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
             if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
                 continue;
@@ -3071,7 +2604,9 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
             }
 
             if (convert_src1_to_q8_1 && !src1_is_contiguous) {
-
+                scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
+                                                     /*num_src=*/2, " : converting src1 to Q8_1");
+                quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, false, stream);
                 /*
                 DPCT1010:92: SYCL uses exceptions to report errors and does
                 not use the error codes. The call was replaced with 0. You
@@ -3164,41 +2699,36 @@ catch (sycl::exception const &exc) {
 }
 
 
-static void
-
-
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
+static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
+    ggml_sycl_op_get_rows(ctx, dst);
 }
 
-static void
-
-
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
+static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_norm(ctx, dst);
 }
 
-static void
-
-
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
+static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_rms_norm(ctx, dst);
 }
 
-static void
-
-
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
+static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_l2_norm(ctx, dst);
 }
 
-static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx,
-
-
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
+static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_group_norm(ctx, dst);
 }
 
 static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                        const ggml_tensor *src1,
                                        ggml_tensor *dst) try {
     GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
-    GGML_ASSERT(src0->
+    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
     GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
     GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -3231,7 +2761,7 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_is_permuted(src0));
-    GGML_ASSERT(src0->
+    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
@@ -3262,146 +2792,182 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
-                                   const
-
-                                   int64_t
-
-
-                                   int64_t r2, int64_t r3,
-                                   const sycl::nd_item<3> &item_ct1) {
-    int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                  item_ct1.get_local_id(2);
-    int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) +
-                  item_ct1.get_local_id(1);
+static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,
+                                   const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
+                                   size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
+                                   int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
+    const int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
+    const int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
 
     if (i13 >= ne13 || i12 >= ne12) {
         return;
     }
 
-    int64_t i03 = i13 / r3;
-    int64_t i02 = i12 / r2;
+    const int64_t i03 = i13 / r3;
+    const int64_t i02 = i12 / r2;
 
-
-
-
+    const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
+    const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
+    uint8_t * dst_bytes = static_cast<uint8_t *>(dst);
+
+    ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
+    ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
+    ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3;
 }
 
-static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
-
-                                           const ggml_tensor *src1,
-                                           ggml_tensor *dst) try {
+static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
+                                           const ggml_tensor * src1, ggml_tensor * dst) try {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
-    GGML_ASSERT(src0->
+    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
+    // TODO: see https://github.com/ggml-org/llama.cpp/pull/13155
+    // Batched mul_mat requires a rewrite to support both oneDNN and non-contiguous dst
+    GGML_ASSERT(ggml_is_contiguous(dst));
 
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    queue_ptr
+    queue_ptr queue = ctx.stream();
 
-
-    sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf = (float *) dst->data;
+    dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });
 
-
+    const sycl::half * src0_f16 = static_cast<const sycl::half *>(src0->data);
+    float * dst_ddf = static_cast<float *>(dst->data);
+
+    const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
+    const size_t type_size_src1 = ggml_type_size(src1->type);
+    GGML_ASSERT(nb10 == type_size_src1);
+
+    // SRC1 strides
+    int64_t s11 = nb11 / type_size_src1;
+    int64_t s12 = nb12 / type_size_src1;
+    int64_t s13 = nb13 / type_size_src1;
     ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
+
+    // convert src1 to fp16
     if (src1->type != GGML_TYPE_F16) {
-
+        scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2,
+                                             " : converting src1 to fp16");
+        const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
+        GGML_ASSERT(to_fp16_nc_sycl != nullptr);
         const int64_t ne_src1 = ggml_nelements(src1);
         src1_f16_alloc.alloc(ne_src1);
-
-
+        to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
+
+        src1_f16 = src1_f16_alloc.get();
+        s11 = ne10;
+        s12 = ne11 * s11;
+        s13 = ne12 * s12;
     }
-    sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
-                                                       : src1_f16_alloc.get();
 
-
+    ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
 
-    dpct::library_data_t
-    dpct::library_data_t
+    dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
+    dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
 
     // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
 
     const float alpha_f32 = 1.0f;
-    const float beta_f32
+    const float beta_f32 = 0.0f;
 
     const void * alpha = &alpha_f32;
     const void * beta = &beta_f32;
 
-    dst_t = (char *) dst_ddf;
-
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
+    GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
+    GGML_ASSERT(ne10 == ne00);
 
     // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
-
-
-
-
-            *
-
-            (
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        {
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;
+
+#if GGML_SYCL_DNNL
+    if (!g_ggml_sycl_disable_dnn) {
+        auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
+            (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
+
+            DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10,
+                src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
+                src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
+                dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
+        };
+
+        if (r2 == 1 && r3 == 1) {
+            if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+                dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
+            }
+            else {
+                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
+                    const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
+                    float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
+                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
+                }
+            }
+        } else {
+            // iterate over batches from smaller set of matrices (matrix 0)
+            for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
+                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
+                    const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
+                    float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
+                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
+                }
+            }
+        }
+    }
+    else
+#endif
+    {
+        if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+            // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
+                                                        oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+                                                        src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
+                                                        src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
+                                                        mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
+        } else {
+            const int ne23 = ne12 * ne13;
+
+            ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
+            ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
+            ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
+
+            sycl::range<3> block_dims(1, ne12, ne13);
+            queue->submit([&](sycl::handler & cgh) {
+                const void ** ptrs_src_get = ptrs_src.get();
+                void ** ptrs_dst_get = ptrs_dst.get();
+                size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
+                size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
+                sycl_parallel_for(cgh, sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                    k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
+                                           nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
+                });
             });
+
+            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
+                *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+                (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
+                (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
+                (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
         }
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            *main_stream, oneapi::mkl::transpose::trans,
-            oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
-            (const void **)(ptrs_src.get() + 0 * ne23),
-            dpct::library_data_t::real_half, nb01 / nb00,
-            (const void **)(ptrs_src.get() + 1 * ne23),
-            dpct::library_data_t::real_half, nb11 / nb10, beta,
-            (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
-            cu_compute_type)));
     }
+} catch (const sycl::exception & exc) {
+    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
+    std::exit(1);
 }
-
-
-
-
-
+
+enum class mul_mat_algo {
+    DMMV = 0,
+    MMVQ = 1,
+    MUL_MAT_SYCL = 2,
+};
 
 inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
     // TODO: accuracy issues in MMQ
@@ -3409,7 +2975,39 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
     return false;
 }
 
-bool
+inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return true;
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q6_K:
+            return !g_ggml_sycl_prioritize_dmmv;
+        default:
+            return false;
+    }
+}
+
+inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return true;
+        default:
+            return false;
+    }
+}
+
+inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q6_K:
+            return true;
+        default:
+            return false;
+    }
+}
+
+static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
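Note: the reorder_qw_* helpers added in the next hunk repack quantized tensors from an array-of-blocks layout (each block interleaves its quants with its scale) into a planar layout with all quants first and all scales at the tail, which gives the reordered mul-mat kernels contiguous loads. A minimal CPU sketch of the idea, using a hypothetical 18-byte block type as a stand-in for block_q4_0 (the real ggml layout and constants differ in detail):

#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical stand-in for block_q4_0: one 16-bit scale followed by 16 bytes
// of packed 4-bit quants (18 bytes per block, no padding).
struct blk {
    uint16_t d;
    uint8_t  qs[16];
};

// Repack nblocks blocks in place:
//   [d0|qs0][d1|qs1]... -> [qs0][qs1]...[d0][d1]...
// reorder_qw_q4_0 performs the same repacking on the device via a staging copy.
static void reorder_blocks(uint8_t * data, int nblocks) {
    std::vector<uint8_t> tmp(data, data + nblocks * sizeof(blk)); // staging copy
    const blk * x = reinterpret_cast<const blk *>(tmp.data());

    uint8_t * qs_out = data;                // quants plane, at the front
    uint8_t * d_out  = data + nblocks * 16; // scales plane, at the tail

    for (int ib = 0; ib < nblocks; ++ib) {
        std::memcpy(qs_out + ib * 16, x[ib].qs, 16);
        std::memcpy(d_out + ib * 2, &x[ib].d, 2);
    }
}

int main() {
    blk blocks[2] = { { 0x1111, {1, 2} }, { 0x2222, {3, 4} } };
    std::vector<uint8_t> buf(sizeof(blocks));
    std::memcpy(buf.data(), blocks, sizeof(blocks));
    reorder_blocks(buf.data(), 2);
    // buf now holds 32 bytes of quants followed by the two 16-bit scales.
    return 0;
}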
@@ -3428,12 +3026,190 @@ bool ggml_sycl_supports_dmmv(enum ggml_type type) {
|
|
3428
3026
|
}
|
3429
3027
|
}
|
3430
3028
|
|
3029
|
+
static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
|
3030
|
+
dpct::queue_ptr stream) {
|
3031
|
+
auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
|
3032
|
+
SYCL_CHECK(
|
3033
|
+
CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
|
3034
|
+
.wait()));
|
3035
|
+
GGML_ASSERT((size % sizeof(block_q4_0) == 0));
|
3036
|
+
GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
|
3037
|
+
int offset_blks = offset / sizeof(block_q4_0);
|
3038
|
+
auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
|
3039
|
+
auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
|
3040
|
+
|
3041
|
+
stream->parallel_for(
|
3042
|
+
size / sizeof(block_q4_0),
|
3043
|
+
[=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
3044
|
+
const block_q4_0* x = (const block_q4_0*)tmp_buf;
|
3045
|
+
const int ib = i;
|
3046
|
+
|
3047
|
+
for (int j = 0; j < QK4_0/2; j ++)
|
3048
|
+
{
|
3049
|
+
*(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
|
3050
|
+
}
|
3051
|
+
*(d_ptr + ib) = x[ib].d;
|
3052
|
+
}).wait_and_throw();
|
3053
|
+
|
3054
|
+
sycl::free(tmp_buf, *stream);
|
3055
|
+
}
|
3056
|
+
|
3057
|
+
static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
|
3058
|
+
GGML_ASSERT(size % sizeof(block_q4_K) == 0);
|
3059
|
+
GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
|
3060
|
+
|
3061
|
+
const int nblocks = size / sizeof(block_q4_K);
|
3062
|
+
|
3063
|
+
auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
|
3064
|
+
SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
|
3065
|
+
|
3066
|
+
auto * qs_ptr = data_device;
|
3067
|
+
auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
|
3068
|
+
auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
|
3069
|
+
|
3070
|
+
stream->parallel_for(nblocks, [=](auto i) {
|
3071
|
+
const block_q4_K * x = (const block_q4_K *) tmp_buf;
|
3072
|
+
const int ib = i;
|
3073
|
+
|
3074
|
+
for (int j = 0; j < QK_K / 2; ++j) {
|
3075
|
+
qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
|
3076
|
+
}
|
3077
|
+
|
3078
|
+
for (int j = 0; j < K_SCALE_SIZE; ++j) {
|
3079
|
+
scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
|
3080
|
+
}
|
3081
|
+
|
3082
|
+
dm_ptr[ib] = x[ib].dm;
|
3083
|
+
}).wait_and_throw();
|
3084
|
+
|
3085
|
+
sycl::free(tmp_buf, *stream);
|
3086
|
+
}
|
3087
|
+
|
3088
|
+
static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
|
3089
|
+
GGML_ASSERT(size % sizeof(block_q6_K) == 0);
|
3090
|
+
GGML_ASSERT(offset % sizeof(block_q6_K) == 0);
|
3091
|
+
|
3092
|
+
const int nblocks = size / sizeof(block_q6_K);
|
3093
|
+
|
3094
|
+
auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
|
3095
|
+
SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
|
3096
|
+
|
3097
|
+
auto * ql_ptr = data_device;
|
3098
|
+
auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks;
|
3099
|
+
auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
|
3100
|
+
sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
|
3101
|
+
|
3102
|
+
stream
|
3103
|
+
->parallel_for(nblocks,
|
3104
|
+
[=](auto i) {
|
3105
|
+
const block_q6_K * x = (const block_q6_K *) tmp_buf;
|
3106
|
+
const int ib = i;
|
3107
|
+
|
3108
|
+
const uint8_t * ql = x[ib].ql;
|
3109
|
+
const uint8_t * qh = x[ib].qh;
|
3110
|
+
uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
|
3111
|
+
uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
|
3112
|
+
uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
|
3113
|
+
|
3114
|
+
for (int j = 0; j < QK_K / 2; ++j) {
|
3115
|
+
base_ql_ptr[j] = ql[j];
|
3116
|
+
}
|
3117
|
+
for (int j = 0; j < QK_K / 4; ++j) {
|
3118
|
+
base_qh_ptr[j] = qh[j];
|
3119
|
+
}
|
3120
|
+
|
3121
|
+
for (int j = 0; j < QK_K / 16; ++j) {
|
3122
|
+
base_scales_ptr[j] = x[ib].scales[j];
|
3123
|
+
}
|
3124
|
+
|
3125
|
+
dm_ptr[ib] = x[ib].d;
|
3126
|
+
})
|
3127
|
+
.wait_and_throw();
|
3128
|
+
|
3129
|
+
sycl::free(tmp_buf, *stream);
|
3130
|
+
}
|
3131
|
+
|
3132
|
+
static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
|
3133
|
+
uint8_t * data_device = (uint8_t *) src0->data;
|
3134
|
+
size_t ncols = src0->ne[0];
|
3135
|
+
size_t nrows = src0->ne[1];
|
3136
|
+
size_t size = ggml_nbytes(src0);
|
3137
|
+
|
3138
|
+
switch (src0->type) {
|
3139
|
+
case GGML_TYPE_Q4_0:
|
3140
|
+
reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
|
3141
|
+
break;
|
3142
|
+
case GGML_TYPE_Q4_K:
|
3143
|
+
reorder_qw_q4_k(data_device, size, 0, stream);
|
3144
|
+
break;
|
3145
|
+
case GGML_TYPE_Q6_K:
|
3146
|
+
reorder_qw_q6_k(data_device, size, 0, stream);
|
3147
|
+
break;
|
3148
|
+
default:
|
3149
|
+
GGML_ABORT("reorder_qw() called with unsupported type");
|
3150
|
+
break;
|
3151
|
+
}
|
3152
|
+
}
|
3153
|
+
|
3154
|
+
static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
|
3155
|
+
return !g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
|
3156
|
+
ctx.opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf.
|
3157
|
+
dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases.
|
3158
|
+
dst->src[1]->ne[1]==1 && dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1;
|
3159
|
+
}
|
3160
|
+
|
3161
|
+
static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */,
|
3162
|
+
ggml_tensor * dst, mul_mat_algo mm_algorithm) {
|
3163
|
+
if (!should_reorder_tensor(*ctx, dst)) {
|
3164
|
+
return;
|
3165
|
+
}
|
3166
|
+
|
3167
|
+
ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
|
3168
|
+
if (!extra || extra->optimized_feature.reorder) {
|
3169
|
+
return; // Skip permutations and already reordered tensors
|
3170
|
+
}
|
3171
|
+
|
3172
|
+
switch (mm_algorithm) {
|
3173
|
+
case mul_mat_algo::DMMV:
|
3174
|
+
if (!ggml_sycl_supports_reorder_dmmv(src0->type)) {
|
3175
|
+
return;
|
3176
|
+
}
|
3177
|
+
break;
|
3178
|
+
case mul_mat_algo::MMVQ:
|
3179
|
+
if (!ggml_sycl_supports_reorder_mmvq(src0->type)) {
|
3180
|
+
return;
|
3181
|
+
}
|
3182
|
+
break;
|
3183
|
+
case mul_mat_algo::MUL_MAT_SYCL:
|
3184
|
+
if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) {
|
3185
|
+
return;
|
3186
|
+
}
|
3187
|
+
break;
|
3188
|
+
}
|
3189
|
+
|
3190
|
+
reorder_qw(src0, ctx->stream());
|
3191
|
+
extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering
|
3192
|
+
}
|
3193
|
+
|
3194
|
+
|
3195
|
+
static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
3196
|
+
return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
|
3197
|
+
src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
|
3198
|
+
}
|
3199
|
+
|
3200
|
+
static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
3201
|
+
return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
|
3202
|
+
src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
|
3203
|
+
}
|
3204
|
+
|
 static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
     const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
     int64_t min_compute_capability = INT_MAX;
 
     if (split) {
-        ggml_backend_sycl_split_buffer_type_context * buft_ctx =
+        ggml_backend_sycl_split_buffer_type_context * buft_ctx =
+            (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
         auto & tensor_split = buft_ctx->tensor_split;
         for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
             // skip devices that are not going to do any work:
@@ -3446,17 +3222,13 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
             }
         }
     } else {
-        min_compute_capability
+        min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
     }
 
     // check data types and tensor shapes for custom matrix multiplication kernels:
-    bool use_dequantize_mul_mat_vec =
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+    bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);
 
-    bool use_mul_mat_vec_q =
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+    bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);
 
     bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
@@ -3468,9 +3240,15 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
 #endif // SYCL_USE_XMX
 
+
     // mmvq path is faster in the CUDA backend.
-    if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
+    if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
+        // Dispatch becomes obscure with the reorder: when the reorder optimization is
+        // enabled, MMVQ takes precedence over DMMV, so the current if-else implementation
+        // requires disabling DMMV when both conditions are met.
+        || (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) {
         use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
+    }
 
     if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // TODO: Refactor and cleanup of mul mat dispatching.
@@ -3482,20 +3260,26 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
         // The kernel from the if path is faster for that specific case, but does not support all mul mats.
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     }
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
     } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // KQ + KQV multi-batch
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
-
+        constexpr bool convert_src1_to_q8_1 = false;
+        opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
     } else if (use_mul_mat_vec_q) {
-
+        constexpr bool convert_src1_to_q8_1 = true;
+        opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
     } else if (use_mul_mat_q) {
-
+        constexpr bool convert_src1_to_q8_1 = true;
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
     } else {
-
+        constexpr bool convert_src1_to_q8_1 = false;
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
     }
 }
 
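The hunk above reduces the eligibility checks to the `can_use_*` helpers and then picks one of the mul-mat paths. A compressed sketch of that precedence, with the checks collapsed into booleans (the split-buffer, batched-F16, and MMQ branches are omitted, and `prefer_mmvq` stands in for the CUDA-backend/reorder condition):

```cpp
// Sketch of the dispatch ladder only; not the full ggml_sycl_mul_mat.
enum class mul_mat_algo { DMMV, MMVQ, MUL_MAT_SYCL };

mul_mat_algo pick_mul_mat_algo(bool use_dmmv, bool use_mmvq, bool prefer_mmvq) {
    // When MMVQ is preferred (CUDA device, or the reorder optimization is
    // active), DMMV is suppressed so the two vector paths cannot both fire.
    if (prefer_mmvq) {
        use_dmmv = use_dmmv && !use_mmvq;
    }
    if (use_dmmv) { return mul_mat_algo::DMMV; }
    if (use_mmvq) { return mul_mat_algo::MMVQ; }
    return mul_mat_algo::MUL_MAT_SYCL;   // generic fallback
}
```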
@@ -3565,9 +3349,11 @@ __dpct_inline__ static void k_copy_dst_from_contiguous(
     }
 }
 
-static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
-                                 const ggml_tensor *src1,
+static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
                                  ggml_tensor *dst) try {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
+    const ggml_tensor *src0 = dst->src[0];
+    const ggml_tensor *src1 = dst->src[1];
     GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
 
     const ggml_tensor *ids = dst->src[2];
@@ -3621,8 +3407,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
                 const int64_t i2 = i12;
 
                 src0_row.data = src0_original + i02*nb02;
-                src1_row.data = src1_original +
-                dst_row.data = dst_original + i1*nb1
+                src1_row.data = src1_original + i11*nb11 + i12*nb12;
+                dst_row.data = dst_original + i1*nb1 + i2*nb2;
 
                 ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
             }
@@ -3663,7 +3449,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
         {
             sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
             sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
-            stream
+            sycl_launch(stream, [&](sycl::handler & cgh) {
                 sycl::local_accessor<int, 0> src1_row_acc(cgh);
 
                 char *__restrict src1_contiguous_get =
@@ -3675,9 +3461,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
                 size_t ids_nb_ct6 = ids->nb[1];
                 size_t ids_nb_ct7 = ids->nb[0];
 
-
-                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
+                sycl_parallel_for(
+                    cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
                     k_copy_src1_to_contiguous(
                         src1_original, src1_contiguous_get,
                         dev_cur_src1_row_get,
@@ -3708,15 +3493,14 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
         {
             sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
             sycl::range<3> grid_dims(1, 1, num_src1_rows);
-            stream
+            sycl_launch(stream, [&](sycl::handler & cgh) {
                 const char *__restrict dst_contiguous_get =
                     dst_contiguous.get();
                 const mmid_row_mapping *__restrict dev_row_mapping_get =
                     dev_row_mapping.get();
 
-
-                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
+                sycl_parallel_for(
+                    cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
                     k_copy_dst_from_contiguous(dst_original,
                                                dst_contiguous_get,
                                                dev_row_mapping_get,
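The two hunks above swap raw stream submissions for `sycl_launch`/`sycl_parallel_for` wrappers. A plausible shape for those shims, assuming they are thin wrappers over standard SYCL submission; their actual definitions live elsewhere in ggml's SYCL common headers and may differ:

```cpp
// Hypothetical sketch of the launch shims, not ggml's definitions.
#include <sycl/sycl.hpp>
#include <utility>

template <typename Fn>
void sycl_launch_sketch(sycl::queue * stream, Fn && fn) {
    stream->submit(std::forward<Fn>(fn));             // fn takes sycl::handler &
}

template <typename Kernel>
void sycl_parallel_for_sketch(sycl::handler & cgh, const sycl::nd_range<3> & range, Kernel && k) {
    cgh.parallel_for(range, std::forward<Kernel>(k)); // one kernel per command group
}
```

Centralizing submission behind such wrappers gives one place to hook debug tracing or graph recording without touching every call site.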
@@ -3733,117 +3517,52 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-static void ggml_sycl_scale(ggml_backend_sycl_context & ctx,
-
-
-
-static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_clamp);
-}
-
-static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                          ggml_tensor *dst) try {
-    const int64_t ne = ggml_nelements(src0);
-    GGML_ASSERT(ne == ggml_nelements(src1));
-
-    GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
-    GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
-
-    GGML_TENSOR_BINARY_OP_LOCALS01;
-
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    queue_ptr main_stream = ctx.stream();
-
-    char * src0_ddc = (char *) src0->data;
-    char * src1_ddc = (char *) src1->data;
-
-    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f32_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f32_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
-        ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
-        ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
-        ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f16_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f16_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) {
-        ggml_cpy_i16_i16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
-        ggml_cpy_i32_i32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else {
-        GGML_LOG_ERROR("%s: unsupported type combination (%s to %s)\n", __func__,
-                       ggml_type_name(src0->type), ggml_type_name(src1->type));
-        GGML_ABORT("fatal error");
-    }
-
-    GGML_UNUSED(dst);
-}
-catch (sycl::exception const &exc) {
-    std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-              << ", line:" << __LINE__ << std::endl;
-    std::exit(1);
-}
-
-static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    // TODO: why do we pass dst as src1 here?
-    ggml_sycl_cpy(ctx, src0, dst, nullptr);
-    GGML_UNUSED(src1);
-}
-
-static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_diag_mask_inf);
-}
-
-static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_soft_max);
+static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_scale(ctx, dst);
 }
 
-static void
-
-
+static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_diag_mask_inf(ctx, dst);
 }
 
-static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx,
-
+static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_pool2d(ctx, dst);
 }
 
-static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx,
-
+static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
+    ggml_sycl_op_im2col(ctx, dst);
 }
 
-static void ggml_sycl_sum(ggml_backend_sycl_context & ctx,
-
-
+static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
+    ggml_sycl_op_sum(ctx, dst);
 }
 
-static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx,
-
-
+static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
+    ggml_sycl_op_sum_rows(ctx, dst);
 }
 
-static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx,
-
-
+static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
+    ggml_sycl_op_argsort(ctx, dst);
 }
 
-static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx,
-
-
+static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
+    ggml_sycl_op_argmax(ctx, dst);
 }
 
-static void ggml_sycl_nop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_UNUSED(src0);
-    GGML_UNUSED(src1);
-    GGML_UNUSED(dst);
-    GGML_UNUSED(ctx);
-}
 
-void ggml_sycl_set_main_device(const int main_device) try {
+static void ggml_sycl_set_main_device(const int main_device) try {
     if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
         return;
     }
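The hunk above converts the old three-operand `(src0, src1, dst)` wrappers to a dst-only shape: sources are read off `dst->src[]`, and each wrapper opens a `scope_op_debug_print` for tracing. A self-contained sketch of that pattern, with minimal stand-in types rather than ggml's:

```cpp
// Stand-in types only; illustrates the dst-only wrapper shape, not ggml code.
#include <cstdio>

struct tensor { tensor * src[3]; const char * name; };
struct ctx_t  { int device; };

struct scope_print {                     // mimics scope_op_debug_print (RAII)
    explicit scope_print(const char * fn, const tensor * t) {
        std::printf("call %s: dst=%s\n", fn, t->name);
    }
    ~scope_print() { std::printf("done\n"); }
};

void example_op(ctx_t & ctx, tensor * dst) {
    scope_print dbg(__func__, dst);
    tensor * src0 = dst->src[0];         // sources travel on dst, not as parameters
    (void) ctx; (void) src0;             // the op's kernel launch would go here
}
```

With one calling shape per op, the big dispatch switch below reduces to a uniform `ggml_sycl_<op>(ctx, dst)` call per case.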
@@ -3864,192 +3583,229 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor *
+static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) try {
     if (!g_sycl_loaded) return false;
 
-
+    if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {
+        ggml_sycl_set_peer_access(dst->src[1]->ne[1], ctx.device);
+    }
 
-    switch (
+    switch (dst->op) {
         case GGML_OP_ARGMAX:
-
+            ggml_sycl_argmax(ctx, dst);
             break;
         case GGML_OP_CONV_TRANSPOSE_1D:
-
+            ggml_sycl_op_conv_transpose_1d(ctx, dst);
             break;
         case GGML_OP_REPEAT:
-
+            ggml_sycl_repeat(ctx, dst);
             break;
         case GGML_OP_GET_ROWS:
-
+            ggml_sycl_get_rows(ctx, dst);
             break;
         case GGML_OP_DUP:
-
+            ggml_sycl_dup(ctx, dst);
             break;
         case GGML_OP_ADD:
         case GGML_OP_ADD1: // TODO: more efficient implementation
-
+            ggml_sycl_add(ctx, dst);
             break;
         case GGML_OP_SUB:
-
+            ggml_sycl_sub(ctx, dst);
             break;
         case GGML_OP_ACC:
-
+            ggml_sycl_acc(ctx, dst);
             break;
         case GGML_OP_MUL:
-
+            ggml_sycl_mul(ctx, dst);
             break;
         case GGML_OP_LOG:
-
+            ggml_sycl_log(ctx, dst);
             break;
         case GGML_OP_DIV:
-
+            ggml_sycl_div(ctx, dst);
             break;
         case GGML_OP_UNARY:
-            switch (ggml_get_unary_op(
+            switch (ggml_get_unary_op(dst)) {
                 case GGML_UNARY_OP_NEG:
-
+                    ggml_sycl_neg(ctx, dst);
                     break;
                 case GGML_UNARY_OP_STEP:
-
+                    ggml_sycl_step(ctx, dst);
                     break;
                 case GGML_UNARY_OP_GELU:
-
+                    ggml_sycl_gelu(ctx, dst);
                     break;
                 case GGML_UNARY_OP_SILU:
-
+                    ggml_sycl_silu(ctx, dst);
                     break;
                 case GGML_UNARY_OP_GELU_QUICK:
-
+                    ggml_sycl_gelu_quick(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_GELU_ERF:
+                    ggml_sycl_gelu_erf(ctx, dst);
                     break;
                 case GGML_UNARY_OP_TANH:
-
+                    ggml_sycl_tanh(ctx, dst);
                     break;
                 case GGML_UNARY_OP_RELU:
-
+                    ggml_sycl_relu(ctx, dst);
                     break;
                 case GGML_UNARY_OP_SIGMOID:
-
+                    ggml_sycl_sigmoid(ctx, dst);
                     break;
                 case GGML_UNARY_OP_HARDSIGMOID:
-
+                    ggml_sycl_hardsigmoid(ctx, dst);
                     break;
                 case GGML_UNARY_OP_HARDSWISH:
-
+                    ggml_sycl_hardswish(ctx, dst);
                     break;
                 case GGML_UNARY_OP_EXP:
-
+                    ggml_sycl_exp(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_SGN:
+                    ggml_sycl_sgn(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_ABS:
+                    ggml_sycl_abs(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_ELU:
+                    ggml_sycl_elu(ctx, dst);
+                    break;
+                default:
+                    return false;
+            }
+            break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(dst)) {
+                case GGML_GLU_OP_REGLU:
+                    ggml_sycl_reglu(ctx, dst);
+                    break;
+                case GGML_GLU_OP_GEGLU:
+                    ggml_sycl_geglu(ctx, dst);
+                    break;
+                case GGML_GLU_OP_SWIGLU:
+                    ggml_sycl_swiglu(ctx, dst);
                     break;
                 default:
                     return false;
             }
             break;
         case GGML_OP_NORM:
-
+            ggml_sycl_norm(ctx, dst);
            break;
         case GGML_OP_GROUP_NORM:
-
+            ggml_sycl_group_norm(ctx, dst);
            break;
         case GGML_OP_CONCAT:
-
+            ggml_sycl_op_concat(ctx, dst);
            break;
         case GGML_OP_UPSCALE:
-
+            ggml_sycl_upscale(ctx, dst);
            break;
         case GGML_OP_PAD:
-
+            ggml_sycl_pad(ctx, dst);
            break;
         case GGML_OP_LEAKY_RELU:
-
+            ggml_sycl_leaky_relu(ctx, dst);
            break;
         case GGML_OP_RMS_NORM:
-
+            ggml_sycl_rms_norm(ctx, dst);
+            break;
+        case GGML_OP_L2_NORM:
+            ggml_sycl_l2_norm(ctx, dst);
            break;
         case GGML_OP_MUL_MAT:
-            if (
+            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
                return false;
            }
-
+            /* ggml_sycl_mul_mat_id is dependent on ggml_sycl_mul_mat */
+            ggml_sycl_mul_mat(ctx, dst->src[0], dst->src[1], dst);
            break;
         case GGML_OP_MUL_MAT_ID:
-            if (
+            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
                return false;
            }
-
+            ggml_sycl_mul_mat_id(ctx, dst);
            break;
         case GGML_OP_OUT_PROD:
-
+            ggml_sycl_op_out_prod(ctx, dst);
            break;
         case GGML_OP_SCALE:
-
+            ggml_sycl_scale(ctx, dst);
            break;
         case GGML_OP_SQR:
-
+            ggml_sycl_sqr(ctx, dst);
            break;
         case GGML_OP_SQRT:
-
+            ggml_sycl_sqrt(ctx, dst);
            break;
         case GGML_OP_SIN:
-
+            ggml_sycl_sin(ctx, dst);
            break;
         case GGML_OP_COS:
-
+            ggml_sycl_cos(ctx, dst);
            break;
         case GGML_OP_CLAMP:
-
+            ggml_sycl_clamp(ctx, dst);
            break;
         case GGML_OP_CPY:
-
+            ggml_sycl_cpy(ctx, dst->src[0], dst->src[1]);
            break;
         case GGML_OP_CONT:
-
+            ggml_sycl_dup(ctx, dst);
            break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
-
+            GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__);
            break;
         case GGML_OP_DIAG_MASK_INF:
-
+            ggml_sycl_diag_mask_inf(ctx, dst);
            break;
         case GGML_OP_SOFT_MAX:
-
+            ggml_sycl_op_soft_max(ctx, dst);
            break;
         case GGML_OP_ROPE:
-
+            ggml_sycl_rope(ctx, dst);
            break;
         case GGML_OP_IM2COL:
-
+            ggml_sycl_im2col(ctx, dst);
            break;
         case GGML_OP_POOL_2D:
-
+            ggml_sycl_pool2d(ctx, dst);
            break;
         case GGML_OP_SUM:
-
+            ggml_sycl_sum(ctx, dst);
            break;
         case GGML_OP_SUM_ROWS:
-
+            ggml_sycl_sum_rows(ctx, dst);
            break;
         case GGML_OP_ARGSORT:
-
+            ggml_sycl_argsort(ctx, dst);
            break;
         case GGML_OP_TIMESTEP_EMBEDDING:
-
+            ggml_sycl_op_timestep_embedding(ctx, dst);
            break;
         case GGML_OP_RWKV_WKV6:
-
+            ggml_sycl_op_rwkv_wkv6(ctx, dst);
+            break;
+        case GGML_OP_RWKV_WKV7:
+            ggml_sycl_op_rwkv_wkv7(ctx, dst);
+            break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            ggml_sycl_op_gated_linear_attn(ctx, dst);
            break;
         default:
            return false;
     }
 
-    if (tensor->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(tensor->src[0]->buffer)) {
-        ggml_sycl_set_peer_access(tensor->src[1]->ne[1], ctx.device);
-    }
-
-    func(ctx, tensor->src[0], tensor->src[1], tensor);
     return true;
+} catch (sycl::exception & e) {
+    std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
+    std::exit(1);
 }
 
 GGML_API void ggml_backend_sycl_get_device_description(int device, char *description,
@@ -4112,6 +3868,9 @@ static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend,
                                                ggml_tensor *tensor,
                                                const void *data, size_t offset,
                                                size_t size) try {
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
+    GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
     ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
 
@@ -4130,13 +3889,16 @@ static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend,
                                                const ggml_tensor *tensor,
                                                void *data, size_t offset,
                                                size_t size) try {
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
+    GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
     ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
 
     GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
     const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
     SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
-        data, (const char *)tensor->data + offset, size)
+        data, (const char *)tensor->data + offset, size)));
 }
 catch (sycl::exception const &exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -4148,7 +3910,13 @@ static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,
                                                const ggml_tensor *src,
                                                ggml_tensor *dst) try {
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-
+    bool is_cpy_supported = dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) &&
+                            ggml_backend_buffer_is_sycl(src->buffer);
+    GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": dst", dst).c_str());
+    GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" src", src).c_str());
+    GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported);
+    if (is_cpy_supported) {
         /*
         DPCT1009:215: SYCL uses exceptions to report errors and does not use the
         error codes. The original code was commented out and a warning string
@@ -4156,7 +3924,7 @@ static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,
         */
         const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
         SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
-            dst->data, src->data, ggml_nbytes(dst))
+            dst->data, src->data, ggml_nbytes(dst))));
         return true;
     }
 
@@ -4169,6 +3937,7 @@ catch (sycl::exception const &exc) {
 }
 
 static void ggml_backend_sycl_synchronize(ggml_backend_t backend) try {
+    GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
     const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
     SYCL_CHECK(CHECK_TRY_ERROR((stream)->wait()));
@@ -4181,11 +3950,9 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-static
-    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
+static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * sycl_ctx, ggml_cgraph * cgraph) {
     ggml_sycl_set_main_device(sycl_ctx->device);
 
-
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
         if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
@@ -4205,7 +3972,82 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
         }
         GGML_ASSERT(ok);
     }
+}
 
+#ifdef GGML_SYCL_GRAPH
+static bool check_graph_compatibility(ggml_cgraph * cgraph) {
+    if (ggml_sycl_info().device_count > 1) {
+        // A sycl_ex::command_graph object can only be created for a single device
+        GGML_LOG_INFO("%s: disabling SYCL graphs due to multiple devices\n", __func__);
+        return false;
+    }
+
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        const ggml_op node_op = cgraph->nodes[i]->op;
+        switch (node_op) {
+            default:
+                break;
+            case GGML_OP_CONCAT:
+                // ggml_sycl_op_concat() does a blocking host wait after memcpy operations,
+                // but wait() can't be called on the events returned by a queue recording
+                // to a graph.
+                [[fallthrough]];
+            case GGML_OP_MUL_MAT_ID:
+                // ggml_sycl_mul_mat_id() does a blocking host wait on the sycl queue after
+                // submitting a memcpy operation, but wait() can't be called on a queue that
+                // is recording to a graph.
+                GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__,
+                              ggml_op_name(node_op));
+                return false;
+        }
+    }
+    return true;
+}
+#endif
+
+static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    auto * sycl_ctx = static_cast<ggml_backend_sycl_context *>(backend->context);
+
+#ifdef GGML_SYCL_GRAPH
+    bool use_sycl_graph = !g_ggml_sycl_disable_graph && check_graph_compatibility(cgraph);
+    if (use_sycl_graph) {
+        const bool graph_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_limited_graph);
+        if (!graph_support) {
+            GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device);
+            ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
+            return GGML_STATUS_SUCCESS;
+        }
+
+        sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
+
+        model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
+        ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
+        model_sycl_graph.end_recording();
+
+        const bool graph_update_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph);
+        if (!sycl_ctx->exec_graph || !graph_update_support) {
+            auto exec_graph = graph_update_support ? model_sycl_graph.finalize(sycl_ex::property::graph::updatable{}) :
+                                                     model_sycl_graph.finalize();
+            sycl_ctx->exec_graph = std::make_unique<
+                sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
+        } else {
+            try {
+                sycl_ctx->exec_graph->update(model_sycl_graph);
+                GGML_SYCL_DEBUG("[SYCL-GRAPH] update success\n");
+            } catch (sycl::exception const & e) {
+                GGML_SYCL_DEBUG("[SYCL-GRAPH] Exception when updating graph, %s\n", e.what());
+                auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});
+                sycl_ctx->exec_graph = std::make_unique<
+                    sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
+            }
+        }
+
+        sycl_ctx->stream()->ext_oneapi_graph(*(sycl_ctx->exec_graph));
+    } else
+#endif
+    {
+        ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
+    }
     return GGML_STATUS_SUCCESS;
 }
 
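The hunk above is the core of the new SYCL-graph path: the whole cgraph is recorded once into a `sycl_ex::command_graph`, finalized, and then replayed as a single submission, with `update()` reusing the executable graph across iterations when the device supports it. A compressed record/finalize/replay skeleton using the same oneAPI experimental graph calls as the diff; the update path, device-aspect checks, and error handling are elided here:

```cpp
// Skeleton only; mirrors the begin_recording/end_recording/finalize/
// ext_oneapi_graph sequence from the hunk above.
#include <sycl/sycl.hpp>

namespace sycl_ex = sycl::ext::oneapi::experimental;

template <typename F>
void record_and_replay(sycl::queue & q, F && enqueue_all_ops) {
    sycl_ex::command_graph graph(q, {sycl_ex::property::graph::assume_buffer_outlives_graph{}});

    graph.begin_recording(q);   // submissions to q now build graph nodes
    enqueue_all_ops(q);         // every op enqueued here is captured, not run
    graph.end_recording();

    auto exec = graph.finalize();   // bake into an executable graph once
    q.ext_oneapi_graph(exec);       // replay the whole graph in a single submit
    q.wait();
}
```

This is also why `check_graph_compatibility()` rejects ops that host-wait mid-graph: `wait()` cannot be called on a queue that is recording.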
@@ -4229,7 +4071,7 @@ catch (sycl::exception const &exc)
 }
 
 static void ggml_backend_sycl_event_wait(ggml_backend_t backend, ggml_backend_event_t event) try {
-
+    GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
     sycl::event* sycl_event = static_cast<sycl::event*>(event->context);
 
     if (ggml_backend_is_sycl(backend)) {
@@ -4270,7 +4112,6 @@ bool ggml_backend_is_sycl(ggml_backend_t backend) {
 }
 
 int ggml_backend_sycl_get_device_count() {
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count\n");
     return ggml_sycl_info().device_count;
 }
 
@@ -4360,7 +4201,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 return true;
             }
             return false;
-        }
+        }
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_NEG:
@@ -4372,9 +4213,26 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_UNARY_OP_HARDSIGMOID:
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_GELU_QUICK:
+                case GGML_UNARY_OP_GELU_ERF:
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_EXP:
-
+                case GGML_UNARY_OP_SGN:
+                case GGML_UNARY_OP_ABS:
+                case GGML_UNARY_OP_ELU:
+#if defined (GGML_SYCL_F16)
+                    return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
+#else
+                    return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
+#endif
+                default:
+                    return false;
+            }
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
+                    return ggml_is_contiguous_1(op->src[0]);
                 default:
                     return false;
             }
@@ -4409,7 +4267,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 return false;
             }
             return true;
-        }
+        }
         case GGML_OP_OUT_PROD:
             return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
         case GGML_OP_GET_ROWS:
@@ -4426,11 +4284,14 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 default:
                     return false;
             }
-        }
+        }
         case GGML_OP_CPY:
             {
                 ggml_type src0_type = op->src[0]->type;
                 ggml_type src1_type = op->src[1]->type;
+                if (src0_type == src1_type && (ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) && src0_type != GGML_TYPE_BF16) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
                     return true;
                 }
@@ -4452,35 +4313,85 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
+                    return true;
+                }
+                if(src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_Q8_0) {
+                    return true;
+                }
+                if(src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_Q5_0) {
+                    return true;
+                }
+                if(src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_Q5_1) {
+                    return true;
+                }
+                if(src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_Q4_0) {
+                    return true;
+                }
+                if(src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_Q4_1) {
+                    return true;
+                }
                 return false;
-            }
+            }
         case GGML_OP_CONCAT:
             {
                 ggml_type src0_type = op->src[0]->type;
                 return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
-            }
+            }
         case GGML_OP_DUP:
         case GGML_OP_ARGMAX:
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
-        case GGML_OP_REPEAT:
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
-
+            return true;
         case GGML_OP_ADD:
         case GGML_OP_ADD1:
-        case GGML_OP_LOG:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
-        case
-
+        case GGML_OP_REPEAT:
+            return true;
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_SIN:
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
+        case GGML_OP_LOG:
+#if defined (GGML_SYCL_F16)
+            return ((op->type == GGML_TYPE_F32 || op->type == GGML_SYCL_F16) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_SYCL_F16) && (op->type == op->src[0]->type));
+#else
+            return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
+#endif
+        case GGML_OP_NORM:
+        case GGML_OP_RMS_NORM:
+            return true;
+        case GGML_OP_L2_NORM:
+        case GGML_OP_GROUP_NORM:
+            return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_SCALE:
             return true;
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
@@ -4488,30 +4399,21 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SOFT_MAX:
             return true;
         case GGML_OP_ROPE:
-        {
-            const int mode = ((const int32_t *) op->op_params)[2];
-            if (mode & GGML_ROPE_TYPE_MROPE) {
-                return false;
-            }
-            if (mode & GGML_ROPE_TYPE_VISION) {
-                return false;
-            }
-            return ggml_is_contiguous(op->src[0]);
-        }
         case GGML_OP_IM2COL:
-
-
+            return true;
+        case GGML_OP_UPSCALE:
+            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
-        case GGML_OP_GROUP_NORM:
-        case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_RWKV_WKV7:
+        case GGML_OP_GATED_LINEAR_ATTN:
             return true;
         default:
             return false;
@@ -4586,6 +4488,7 @@ static void ggml_backend_sycl_device_event_free(ggml_backend_dev_t dev, ggml_bac
 
 static void ggml_backend_sycl_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) try {
     GGML_UNUSED(dev);
+    GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
 
     sycl::event *sycl_event = static_cast<sycl::event *>(event->context);
     SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait()));
@@ -4638,10 +4541,9 @@ static ggml_backend_dev_t ggml_backend_sycl_reg_get_device(ggml_backend_reg_t re
 static void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, const char *name) {
     GGML_UNUSED(reg);
 
-
-
-
-    //}
+    if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
+        return (void *)ggml_backend_sycl_split_buffer_type;
+    }
 
     // SYCL doesn't support registering host memory, left here for reference
     // "ggml_backend_register_host_buffer"
|