whispercpp 1.3.1 → 1.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +4 -3
- data/README.md +92 -31
- data/Rakefile +26 -7
- data/ext/.gitignore +5 -7
- data/ext/dependencies.rb +61 -0
- data/ext/extconf.rb +21 -198
- data/ext/options.rb +221 -0
- data/ext/ruby_whisper.c +159 -0
- data/ext/ruby_whisper.h +17 -2
- data/ext/ruby_whisper_context.c +641 -0
- data/ext/ruby_whisper_error.c +52 -0
- data/ext/ruby_whisper_model.c +232 -0
- data/ext/ruby_whisper_params.c +1301 -0
- data/ext/ruby_whisper_segment.c +143 -0
- data/ext/ruby_whisper_transcribe.cpp +87 -0
- data/ext/ruby_whisper_vad_params.c +288 -0
- data/ext/sources/.dockerignore +3 -0
- data/ext/sources/.github/workflows/bindings-ruby.yml +21 -0
- data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
- data/ext/sources/CMakeLists.txt +251 -0
- data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
- data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
- data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
- data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
- data/ext/sources/bindings/javascript/package.json +26 -0
- data/ext/sources/bindings/javascript/whisper.js +19 -0
- data/ext/sources/build-xcframework.sh +547 -0
- data/ext/sources/ci/run.sh +336 -0
- data/ext/sources/close-issue.yml +28 -0
- data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
- data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
- data/ext/sources/cmake/build-info.cmake +60 -0
- data/ext/sources/cmake/git-vars.cmake +22 -0
- data/ext/sources/cmake/whisper-config.cmake.in +65 -0
- data/ext/sources/cmake/whisper.pc.in +10 -0
- data/ext/sources/examples/CMakeLists.txt +124 -0
- data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +37 -0
- data/ext/sources/examples/addon.node/addon.cpp +438 -0
- data/ext/sources/examples/addon.node/index.js +54 -0
- data/ext/sources/examples/addon.node/package.json +16 -0
- data/ext/sources/examples/bench/CMakeLists.txt +8 -0
- data/ext/sources/examples/bench/bench.cpp +175 -0
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
- data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
- data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
- data/ext/sources/examples/cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/cli/cli.cpp +1294 -0
- data/ext/sources/examples/coi-serviceworker.js +146 -0
- data/ext/sources/examples/command/CMakeLists.txt +10 -0
- data/ext/sources/examples/command/command.cpp +776 -0
- data/ext/sources/examples/command/commands.txt +9 -0
- data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
- data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
- data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
- data/ext/sources/examples/common-ggml.cpp +238 -0
- data/ext/sources/examples/common-ggml.h +18 -0
- data/ext/sources/examples/common-sdl.cpp +227 -0
- data/ext/sources/examples/common-sdl.h +49 -0
- data/ext/sources/examples/common-whisper.cpp +168 -0
- data/ext/sources/examples/common-whisper.h +24 -0
- data/ext/sources/examples/common.cpp +675 -0
- data/ext/sources/examples/common.h +322 -0
- data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
- data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
- data/ext/sources/examples/generate-karaoke.sh +57 -0
- data/ext/sources/examples/grammar-parser.cpp +423 -0
- data/ext/sources/examples/grammar-parser.h +29 -0
- data/ext/sources/examples/helpers.js +191 -0
- data/ext/sources/examples/json.hpp +24596 -0
- data/ext/sources/examples/livestream.sh +112 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
- data/ext/sources/examples/lsp/lsp.cpp +467 -0
- data/ext/sources/examples/lsp/whisper.vim +362 -0
- data/ext/sources/examples/miniaudio.h +93468 -0
- data/ext/sources/examples/python/test_whisper_processor.py +7 -0
- data/ext/sources/examples/python/whisper_processor.py +54 -0
- data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
- data/ext/sources/examples/quantize/quantize.cpp +223 -0
- data/ext/sources/examples/server/CMakeLists.txt +12 -0
- data/ext/sources/examples/server/bench.js +29 -0
- data/ext/sources/examples/server/httplib.h +10497 -0
- data/ext/sources/examples/server/server.cpp +1091 -0
- data/ext/sources/examples/server.py +115 -0
- data/ext/sources/examples/stb_vorbis.c +5584 -0
- data/ext/sources/examples/stream/CMakeLists.txt +10 -0
- data/ext/sources/examples/stream/stream.cpp +429 -0
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
- data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
- data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
- data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
- data/ext/sources/examples/sycl/build.sh +22 -0
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
- data/ext/sources/examples/sycl/run-whisper.sh +17 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +40 -0
- data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
- data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +1746 -0
- data/ext/sources/examples/talk-llama/llama-arch.h +437 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +374 -0
- data/ext/sources/examples/talk-llama/llama-batch.h +89 -0
- data/ext/sources/examples/talk-llama/llama-chat.cpp +663 -0
- data/ext/sources/examples/talk-llama/llama-chat.h +58 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +2676 -0
- data/ext/sources/examples/talk-llama/llama-context.h +276 -0
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
- data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
- data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
- data/ext/sources/examples/talk-llama/llama-graph.cpp +1618 -0
- data/ext/sources/examples/talk-llama/llama-graph.h +640 -0
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +95 -0
- data/ext/sources/examples/talk-llama/llama-hparams.h +190 -0
- data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
- data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
- data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
- data/ext/sources/examples/talk-llama/llama-io.h +35 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2739 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +502 -0
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +379 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +32 -0
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
- data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1138 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +281 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +13814 -0
- data/ext/sources/examples/talk-llama/llama-model.h +425 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +966 -0
- data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
- data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +3340 -0
- data/ext/sources/examples/talk-llama/llama-vocab.h +131 -0
- data/ext/sources/examples/talk-llama/llama.cpp +354 -0
- data/ext/sources/examples/talk-llama/llama.h +1377 -0
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
- data/ext/sources/examples/talk-llama/speak +40 -0
- data/ext/sources/examples/talk-llama/speak.bat +1 -0
- data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
- data/ext/sources/examples/talk-llama/talk-llama.cpp +808 -0
- data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
- data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +849 -0
- data/ext/sources/examples/talk-llama/unicode.h +66 -0
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +143 -0
- data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
- data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +249 -0
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
- data/ext/sources/ggml/CMakeLists.txt +390 -0
- data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
- data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
- data/ext/sources/ggml/cmake/common.cmake +26 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
- data/ext/{ggml → sources/ggml}/include/ggml-alloc.h +1 -1
- data/ext/{ggml → sources/ggml}/include/ggml-backend.h +9 -7
- data/ext/{ggml → sources/ggml}/include/ggml-cpp.h +2 -1
- data/ext/{ggml → sources/ggml}/include/ggml-cpu.h +9 -1
- data/ext/{ggml → sources/ggml}/include/ggml-metal.h +1 -1
- data/ext/{ggml → sources/ggml}/include/ggml-opt.h +49 -28
- data/ext/{ggml → sources/ggml}/include/ggml-rpc.h +6 -1
- data/ext/{ggml → sources/ggml}/include/ggml-vulkan.h +0 -2
- data/ext/{ggml → sources/ggml}/include/ggml.h +182 -265
- data/ext/sources/ggml/include/gguf.h +202 -0
- data/ext/sources/ggml/src/CMakeLists.txt +346 -0
- data/ext/{ggml → sources/ggml}/src/ggml-alloc.c +34 -29
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- data/ext/{ggml → sources/ggml}/src/ggml-backend-impl.h +1 -2
- data/ext/{ggml → sources/ggml}/src/ggml-backend-reg.cpp +87 -53
- data/ext/{ggml → sources/ggml}/src/ggml-backend.cpp +26 -14
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +74 -0
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.cpp +10 -4
- data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.h +5 -5
- data/ext/{ggml → sources/ggml}/src/ggml-cann/aclnn_ops.cpp +1272 -1506
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cann/common.h +135 -1
- data/ext/{ggml → sources/ggml}/src/ggml-cann/ggml-cann.cpp +564 -146
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/dup.cpp +3 -5
- data/ext/{ggml → sources/ggml}/src/ggml-common.h +12 -8
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +504 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.cpp +2 -1
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +72 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/cpu-feats-x86.cpp +5 -1
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +6431 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-impl.h +163 -41
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.c +4029 -1117
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3510 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu.cpp +67 -18
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3544 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +8903 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.h +110 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +252 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +818 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
- data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
- data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
- data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
- data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +828 -0
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
- data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +730 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +26 -0
- data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
- data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
- data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
- data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1471 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +634 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3505 -0
- data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
- data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +336 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
- data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +153 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
- data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +289 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +59 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/cuda.h +1 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/hip.h +57 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/musa.h +7 -1
- data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
- data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +131 -0
- data/ext/{ggml → sources/ggml}/src/ggml-impl.h +64 -19
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +120 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +622 -0
- data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.m +2178 -1064
- data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.metal +1575 -1218
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
- data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +96 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +5124 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
- data/ext/{ggml → sources/ggml}/src/ggml-opt.cpp +373 -190
- data/ext/{ggml → sources/ggml}/src/ggml-quants.c +114 -120
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
- data/ext/{ggml → sources/ggml}/src/ggml-rpc/ggml-rpc.cpp +480 -73
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +345 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/common.cpp +20 -32
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +589 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/concat.cpp +32 -33
- data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/conv.cpp +4 -2
- data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/convert.cpp +104 -28
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +700 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +791 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/dmmv.cpp +156 -17
- data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2957 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1511 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +75 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +99 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/ggml-sycl.cpp +1004 -1240
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
- data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmq.cpp +0 -1
- data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmvq.cpp +261 -166
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/norm.cpp +204 -81
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/outprod.cpp +8 -17
- data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +83 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +361 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/softmax.cpp +35 -25
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/tsembd.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1215 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +293 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +196 -0
- data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
- data/ext/{ggml → sources/ggml}/src/ggml-vulkan/ggml-vulkan.cpp +3130 -1087
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +52 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
- data/ext/{ggml → sources/ggml}/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
- data/ext/{ggml → sources/ggml}/src/ggml.c +676 -1820
- data/ext/sources/ggml/src/gguf.cpp +1330 -0
- data/ext/{include → sources/include}/whisper.h +68 -2
- data/ext/sources/src/CMakeLists.txt +143 -0
- data/ext/{src → sources/src}/coreml/whisper-decoder-impl.h +27 -15
- data/ext/{src → sources/src}/coreml/whisper-decoder-impl.m +35 -10
- data/ext/{src → sources/src}/coreml/whisper-encoder-impl.h +21 -9
- data/ext/{src → sources/src}/coreml/whisper-encoder-impl.m +28 -3
- data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
- data/ext/sources/src/whisper-arch.h +197 -0
- data/ext/{src → sources/src}/whisper.cpp +1905 -374
- data/ext/sources/tests/CMakeLists.txt +105 -0
- data/ext/sources/tests/earnings21/eval.mk +58 -0
- data/ext/sources/tests/earnings21/eval.py +68 -0
- data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
- data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
- data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
- data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
- data/ext/sources/tests/earnings21/requirements.txt +6 -0
- data/ext/sources/tests/en-0-ref.txt +1 -0
- data/ext/sources/tests/en-1-ref.txt +1 -0
- data/ext/sources/tests/en-2-ref.txt +1 -0
- data/ext/sources/tests/es-0-ref.txt +1 -0
- data/ext/sources/tests/librispeech/eval.mk +39 -0
- data/ext/sources/tests/librispeech/eval.py +47 -0
- data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
- data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
- data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
- data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
- data/ext/sources/tests/librispeech/requirements.txt +6 -0
- data/ext/sources/tests/run-tests.sh +130 -0
- data/ext/sources/tests/test-c.c +3 -0
- data/ext/sources/tests/test-vad-full.cpp +54 -0
- data/ext/sources/tests/test-vad.cpp +83 -0
- data/ext/sources/tests/test-whisper.js +58 -0
- data/extsources.rb +33 -5
- data/lib/whisper/model/uri.rb +149 -128
- data/sig/whisper.rbs +480 -0
- data/tests/helper.rb +28 -0
- data/tests/test_callback.rb +45 -3
- data/tests/test_error.rb +2 -2
- data/tests/test_model.rb +38 -0
- data/tests/test_package.rb +18 -3
- data/tests/test_params.rb +145 -8
- data/tests/test_segment.rb +10 -19
- data/tests/test_vad.rb +19 -0
- data/tests/test_vad_params.rb +103 -0
- data/tests/test_whisper.rb +37 -37
- data/whispercpp.gemspec +5 -4
- metadata +766 -111
- data/ext/cpu.mk +0 -9
- data/ext/examples/dr_wav.h +0 -8815
- data/ext/ggml/src/ggml-cann/aclnn_ops.h +0 -592
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -4262
- data/ext/ggml/src/ggml-cpu/ggml-cpu.c +0 -14123
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +0 -1884
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +0 -14
- data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +0 -288
- data/ext/ggml/src/ggml-sycl/element_wise.cpp +0 -1030
- data/ext/ggml/src/ggml-sycl/im2col.cpp +0 -126
- data/ext/ggml/src/ggml-sycl/rope.cpp +0 -276
- data/ext/ggml/src/ggml-sycl/wkv6.cpp +0 -141
- data/ext/metal-embed.mk +0 -17
- data/ext/metal.mk +0 -6
- data/ext/ruby_whisper.cpp +0 -1909
- data/ext/scripts/get-flags.mk +0 -38
- data/lib/whisper.rb +0 -2
- /data/ext/{ggml → sources/ggml}/include/ggml-blas.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-cann.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-cuda.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-kompute.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-opencl.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-sycl.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/common.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/ggml-amx.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-blas/ggml-blas.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/ascendc_kernels.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f16.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f32.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/common.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-aarch64.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-kompute/ggml-kompute.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-quants.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-threading.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-threading.h +0 -0
- /data/ext/{src → sources/src}/coreml/whisper-encoder.h +0 -0
- /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.cpp +0 -0
- /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.h +0 -0
@@ -0,0 +1,2957 @@
|
|
1
|
+
//
|
2
|
+
// MIT license
|
3
|
+
// Copyright (C) 2024 Intel Corporation
|
4
|
+
// SPDX-License-Identifier: MIT
|
5
|
+
//
|
6
|
+
|
7
|
+
//
|
8
|
+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
9
|
+
// See https://llvm.org/LICENSE.txt for license information.
|
10
|
+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
11
|
+
//
|
12
|
+
|
13
|
+
#ifndef GGML_SYCL_DPCT_HELPER_HPP
|
14
|
+
#define GGML_SYCL_DPCT_HELPER_HPP
|
15
|
+
|
16
|
+
#include <sycl/sycl.hpp>
|
17
|
+
#include <sycl/half_type.hpp>
|
18
|
+
#include <syclcompat/math.hpp>
|
19
|
+
#include <map>
|
20
|
+
|
21
|
+
#ifdef GGML_SYCL_USE_INTEL_ONEMKL
|
22
|
+
#include <oneapi/mkl.hpp>
|
23
|
+
// Alias oneapi::mkl as oneapi::math so the same namespace can be used with both Intel oneMKL and oneMath
|
24
|
+
namespace oneapi {
|
25
|
+
namespace math = mkl;
|
26
|
+
}
|
27
|
+
#else
|
28
|
+
#include <oneapi/math.hpp>
|
29
|
+
#endif
|
30
|
+
|
31
|
+
#include "ggml.h"
|
32
|
+
|
33
|
+
#if defined(__linux__)
|
34
|
+
#include <sys/mman.h>
|
35
|
+
#elif defined(_WIN64)
|
36
|
+
#ifndef NOMINMAX
|
37
|
+
#define NOMINMAX
|
38
|
+
#endif
|
39
|
+
#include <windows.h>
|
40
|
+
#else
|
41
|
+
#error "Only support Windows and Linux."
|
42
|
+
#endif
|
43
|
+
|
44
|
+
#if defined(__linux__)
|
45
|
+
#include <unistd.h>
|
46
|
+
#include <sys/syscall.h>
|
47
|
+
#endif
|
48
|
+
#if defined(_WIN64)
|
49
|
+
#ifndef NOMINMAX
|
50
|
+
#define NOMINMAX
|
51
|
+
#endif
|
52
|
+
#include <windows.h>
|
53
|
+
#endif
|
54
|
+
|
55
|
+
// NOTE(review): presumably stands in for a CUDA architecture level in
// DPCT-migrated code -- confirm against the users of this macro.
#define DPCT_COMPATIBILITY_TEMP (900)

// Compiler-specific spellings for alignment and inlining control.
//   __dpct_align__(n)  : request n-byte alignment for a type or variable
//   __dpct_inline__    : force inlining
//   __dpct_noinline__  : forbid inlining
#if defined(_MSC_VER)
#    define __dpct_align__(n) __declspec(align(n))
#    define __dpct_inline__ __forceinline
#    define __dpct_noinline__ __declspec(noinline)
#else
#    define __dpct_align__(n) __attribute__((aligned(n)))
#    define __dpct_inline__ __inline__ __attribute__((always_inline))
#    define __dpct_noinline__ __attribute__((noinline))
#endif
|
70
|
+
|
71
|
+
/// Returns a short lowercase label for the SYCL device type of \p Device:
/// "cpu", "gpu", "host", "acc" (accelerator), or "unknown" for any other
/// device type.
inline std::string get_device_type_name(const sycl::device &Device) {
    const auto kind = Device.get_info<sycl::info::device::device_type>();

    if (kind == sycl::info::device_type::cpu) {
        return "cpu";
    }
    if (kind == sycl::info::device_type::gpu) {
        return "gpu";
    }
    if (kind == sycl::info::device_type::host) {
        return "host";
    }
    if (kind == sycl::info::device_type::accelerator) {
        return "acc";
    }
    return "unknown";
}
|
86
|
+
|
87
|
+
/// Builds the label "<backend>:<device type>" for \p device, where the
/// backend is written with the stream insertion operator of sycl::backend
/// and the device type comes from get_device_type_name().
inline std::string get_device_backend_and_type(const sycl::device &device) {
    std::stringstream label;
    label << device.get_backend() << ":" << get_device_type_name(device);
    return label.str();
}
|
93
|
+
|
94
|
+
template <typename Ts> struct matrix_info_t {
|
95
|
+
oneapi::math::transpose transpose_info[2];
|
96
|
+
Ts value_info[2];
|
97
|
+
std::int64_t size_info[3];
|
98
|
+
std::int64_t ld_info[3];
|
99
|
+
std::int64_t groupsize_info;
|
100
|
+
};
|
101
|
+
|
102
|
+
inline auto get_onemath_backend(sycl::queue& queue)
|
103
|
+
#if defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
|
104
|
+
-> sycl::queue&
|
105
|
+
#endif
|
106
|
+
{
|
107
|
+
// If the backend is known at compile-time, use oneMath backend_selector to use
|
108
|
+
// compile-time dispatching and avoid the need to dlopen libraries. Otherwise
|
109
|
+
// fallback to runtime dispatching.
|
110
|
+
#if defined(GGML_SYCL_NVIDIA)
|
111
|
+
return oneapi::math::backend_selector<oneapi::math::backend::cublas>{ queue };
|
112
|
+
#elif defined(GGML_SYCL_AMD)
|
113
|
+
return oneapi::math::backend_selector<oneapi::math::backend::rocblas>{ queue };
|
114
|
+
#elif defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
|
115
|
+
return queue;
|
116
|
+
#else
|
117
|
+
static_assert(false, "Unsupported backend");
|
118
|
+
#endif
|
119
|
+
}
|
120
|
+
|
121
|
+
namespace dpct
|
122
|
+
{
|
123
|
+
// Short aliases used throughout the dpct helpers.
using queue_ptr  = sycl::queue *;          // pointer to a SYCL queue
using event_ptr  = sycl::event *;          // pointer to a SYCL event
using device_ptr = char *;                 // char pointer alias
using byte_t     = uint8_t;                // single byte
using buffer_t   = sycl::buffer<byte_t>;   // SYCL buffer of bytes
|
128
|
+
|
129
|
+
/// SYCL default exception handler
|
130
|
+
inline auto exception_handler = [](sycl::exception_list exceptions)
|
131
|
+
{
|
132
|
+
for (std::exception_ptr const &e : exceptions)
|
133
|
+
{
|
134
|
+
try
|
135
|
+
{
|
136
|
+
std::rethrow_exception(e);
|
137
|
+
}
|
138
|
+
catch (sycl::exception const &e)
|
139
|
+
{
|
140
|
+
std::cerr << "Caught asynchronous SYCL exception:" << std::endl
|
141
|
+
<< e.what() << std::endl
|
142
|
+
<< "Exception caught at file:" << __FILE__
|
143
|
+
<< ", line:" << __LINE__ << std::endl;
|
144
|
+
}
|
145
|
+
}
|
146
|
+
};
|
147
|
+
|
148
|
+
/// Status codes: `success` is 0, `default_error` is 999.
enum error_code { success = 0, default_error = 999 };
|
153
|
+
|
154
|
+
/// Direction selector for memory copies; enumerators are numbered 0..4 in
/// declaration order.
enum memcpy_direction {
    host_to_host     = 0,
    host_to_device   = 1,
    device_to_host   = 2,
    device_to_device = 3,
    automatic        = 4
};
|
162
|
+
|
163
|
+
/// Kinds of memory a piece of data can live in.
enum memory_region {
    /// device global memory
    global = 0,
    /// device constant memory
    constant = 1,
    /// device local memory
    local = 2,
    /// memory which can be accessed by host and device
    shared = 3,
};
|
170
|
+
|
171
|
+
/// Element-type tags, stored in one unsigned char. Enumerators are numbered
/// consecutively from 0; the trailing `library_data_t_size` therefore equals
/// the number of real tags (31).
enum class library_data_t : unsigned char {
    // floating point: real / complex pairs
    real_float = 0, complex_float,
    real_double,    complex_double,
    real_half,      complex_half,
    real_bfloat16,  complex_bfloat16,
    // integers: real / complex pairs, signed then unsigned, by width
    real_int4,      complex_int4,
    real_uint4,     complex_uint4,
    real_int8,      complex_int8,
    real_uint8,     complex_uint8,
    real_int16,     complex_int16,
    real_uint16,    complex_uint16,
    real_int32,     complex_int32,
    real_uint32,    complex_uint32,
    real_int64,     complex_int64,
    real_uint64,    complex_uint64,
    // 8-bit variants with a numeric suffix
    real_int8_4,
    real_int8_32,
    real_uint8_4,
    // count of the tags above; keep last
    library_data_t_size
};
|
206
|
+
|
207
|
+
/// Type mapping exposed through the member alias `T2`.
/// The primary template is the identity: `DataType<X>::T2` is `X`.
template <typename ValueT>
struct DataType {
    typedef ValueT T2;
};
|
212
|
+
template <typename T>
|
213
|
+
struct DataType<sycl::vec<T, 2>>
|
214
|
+
{
|
215
|
+
using T2 = std::complex<T>;
|
216
|
+
};
|
217
|
+
|
218
|
+
/// Releases \p event with `delete`.
/// NOTE(review): presumably the pointer must come from `new sycl::event` --
/// confirm against the code that creates these events.
static void destroy_event(event_ptr event) { delete event; }
|
222
|
+
|
223
|
+
/// Returns the operating-system identifier of the calling thread:
/// the result of the `gettid` system call on Linux, or
/// GetCurrentThreadId() on 64-bit Windows. Other platforms do not compile.
static inline unsigned int get_tid()
{
#if defined(__linux__)
    // syscall() yields a long; it is narrowed to the unsigned return type.
    const long tid = syscall(SYS_gettid);
    return static_cast<unsigned int>(tid);
#elif defined(_WIN64)
    return GetCurrentThreadId();
#else
#error "Only support Windows and Linux."
#endif
}
|
233
|
+
|
234
|
+
namespace detail
|
235
|
+
{
|
236
|
+
/// Extracts the major and minor version numbers from the version string
/// reported by \p dev (sycl::info::device::version).
///
/// The version string is expected in one of these formats:
///   a. OpenCL<space><major.minor><space><vendor-specific-information>
///   b. <major.minor>
///   c. <AmdGcnArchName> e.g. gfx1030
///
/// \param dev    device whose version string is parsed
/// \param major  receives the number starting at the first digit of the string
/// \param minor  receives the number directly following the first '.' found
///               at or after that digit; 0 when there is no such number
///               (format c.)
///
/// When the string contains no digit at all, both outputs are set to 0
/// instead of letting std::stoi throw std::invalid_argument. std::stoi can
/// still throw std::out_of_range for a number that does not fit in an int.
static void get_version(const sycl::device &dev, int &major, int &minor)
{
    const std::string ver = dev.get_info<sycl::info::device::version>();

    // Plain range check: unlike isdigit() this is well defined for negative
    // char values.
    const auto is_digit = [](char c) { return c >= '0' && c <= '9'; };

    // Skip any textual prefix such as "OpenCL " or "gfx".
    std::string::size_type i = 0;
    while (i < ver.size() && !is_digit(ver[i])) {
        ++i;
    }
    if (i == ver.size()) {
        // No digit anywhere: nothing to parse.
        major = 0;
        minor = 0;
        return;
    }
    major = std::stoi(ver.substr(i));

    // a. and b. carry ".<minor>" after the major number; c. does not.
    const std::string::size_type dot = ver.find('.', i);
    if (dot != std::string::npos && dot + 1 < ver.size() && is_digit(ver[dot + 1])) {
        minor = std::stoi(ver.substr(dot + 1));
    } else {
        minor = 0;
    }
}
|
265
|
+
|
266
|
+
/// Thin wrapper around a single value of type `ValueT` that converts
/// implicitly from and to `ValueT`. `TagT` is not used in the body; it only
/// makes instantiations with different tags distinct types.
/// A default-constructed object leaves the stored value default-initialized.
template <typename TagT, typename ValueT>
class generic_error_type
{
    ValueT _value;

public:
    generic_error_type() = default;

    /// Wraps \p value (implicit conversion from ValueT).
    generic_error_type(ValueT value) : _value{value} {}

    /// Yields the wrapped value (implicit conversion to ValueT).
    operator ValueT() const { return _value; }
};
|
277
|
+
|
278
|
+
} // namespace detail
|
279
|
+
|
280
|
+
/// Pitched 2D/3D memory data.
///
/// Value holder for a data pointer plus a pitch and x / y extents. The class
/// only stores these four values; it never allocates or frees the pointer.
/// NOTE(review): the units of pitch / x / y are not visible here -- confirm
/// against the code that fills this class.
class pitched_data
{
public:
    /// Creates an empty descriptor: null pointer, all sizes zero.
    pitched_data() : pitched_data(nullptr, 0, 0, 0) {}

    /// Creates a descriptor from the given pointer, pitch and extents.
    pitched_data(void *data, size_t pitch, size_t x, size_t y)
        : _data(data), _pitch(pitch), _x(x), _y(y) {}

    // The getters are const so they can also be called on a const object.
    void *get_data_ptr() const { return _data; }
    void set_data_ptr(void *data) { _data = data; }

    size_t get_pitch() const { return _pitch; }
    void set_pitch(size_t pitch) { _pitch = pitch; }

    size_t get_x() const { return _x; }
    void set_x(size_t x) { _x = x; }

    size_t get_y() const { return _y; }
    void set_y(size_t y) { _y = y; }

private:
    void *_data;
    size_t _pitch, _x, _y;
};
|
304
|
+
|
305
|
+
class device_info
|
306
|
+
{
|
307
|
+
public:
|
308
|
+
// get interface
|
309
|
+
const char *get_name() const { return _name; }
|
310
|
+
char *get_name() { return _name; }
|
311
|
+
template <typename WorkItemSizesTy = sycl::range<3>,
|
312
|
+
std::enable_if_t<std::is_same_v<WorkItemSizesTy, sycl::range<3>> ||
|
313
|
+
std::is_same_v<WorkItemSizesTy, int *>,
|
314
|
+
int> = 0>
|
315
|
+
auto get_max_work_item_sizes() const
|
316
|
+
{
|
317
|
+
if constexpr (std::is_same_v<WorkItemSizesTy, sycl::range<3>>)
|
318
|
+
return sycl::range<3>(_max_work_item_sizes_i[0],
|
319
|
+
_max_work_item_sizes_i[1],
|
320
|
+
_max_work_item_sizes_i[2]);
|
321
|
+
else
|
322
|
+
{
|
323
|
+
return _max_work_item_sizes_i;
|
324
|
+
}
|
325
|
+
}
|
326
|
+
template <typename WorkItemSizesTy = sycl::range<3>,
|
327
|
+
std::enable_if_t<std::is_same_v<WorkItemSizesTy, sycl::range<3>> ||
|
328
|
+
std::is_same_v<WorkItemSizesTy, int *>,
|
329
|
+
int> = 0>
|
330
|
+
auto get_max_work_item_sizes()
|
331
|
+
{
|
332
|
+
if constexpr (std::is_same_v<WorkItemSizesTy, sycl::range<3>>)
|
333
|
+
return sycl::range<3>(_max_work_item_sizes_i[0],
|
334
|
+
_max_work_item_sizes_i[1],
|
335
|
+
_max_work_item_sizes_i[2]);
|
336
|
+
else
|
337
|
+
{
|
338
|
+
return _max_work_item_sizes_i;
|
339
|
+
}
|
340
|
+
}
|
341
|
+
bool get_host_unified_memory() const { return _host_unified_memory; }
|
342
|
+
int get_major_version() const { return _major; }
|
343
|
+
int get_minor_version() const { return _minor; }
|
344
|
+
int get_integrated() const { return _integrated; }
|
345
|
+
int get_max_clock_frequency() const { return _frequency; }
|
346
|
+
int get_max_compute_units() const { return _max_compute_units; }
|
347
|
+
int get_max_work_group_size() const { return _max_work_group_size; }
|
348
|
+
int get_max_sub_group_size() const { return _max_sub_group_size; }
|
349
|
+
int get_max_work_items_per_compute_unit() const
|
350
|
+
{
|
351
|
+
return _max_work_items_per_compute_unit;
|
352
|
+
}
|
353
|
+
int get_max_register_size_per_work_group() const
|
354
|
+
{
|
355
|
+
return _max_register_size_per_work_group;
|
356
|
+
}
|
357
|
+
template <typename NDRangeSizeTy = size_t *,
|
358
|
+
std::enable_if_t<std::is_same_v<NDRangeSizeTy, size_t *> ||
|
359
|
+
std::is_same_v<NDRangeSizeTy, int *>,
|
360
|
+
int> = 0>
|
361
|
+
auto get_max_nd_range_size() const
|
362
|
+
{
|
363
|
+
if constexpr (std::is_same_v<NDRangeSizeTy, size_t *>)
|
364
|
+
return _max_nd_range_size;
|
365
|
+
else
|
366
|
+
return _max_nd_range_size_i;
|
367
|
+
}
|
368
|
+
template <typename NDRangeSizeTy = size_t *,
|
369
|
+
std::enable_if_t<std::is_same_v<NDRangeSizeTy, size_t *> ||
|
370
|
+
std::is_same_v<NDRangeSizeTy, int *>,
|
371
|
+
int> = 0>
|
372
|
+
auto get_max_nd_range_size()
|
373
|
+
{
|
374
|
+
if constexpr (std::is_same_v<NDRangeSizeTy, size_t *>)
|
375
|
+
return _max_nd_range_size;
|
376
|
+
else
|
377
|
+
return _max_nd_range_size_i;
|
378
|
+
}
|
379
|
+
size_t get_global_mem_size() const { return _global_mem_size; }
|
380
|
+
size_t get_local_mem_size() const { return _local_mem_size; }
|
381
|
+
size_t get_max_mem_alloc_size() const { return _max_mem_alloc_size; }
|
382
|
+
/// Returns the maximum clock rate of device's global memory in kHz. If
|
383
|
+
/// compiler does not support this API then returns default value 3200000 kHz.
|
384
|
+
unsigned int get_memory_clock_rate() const { return _memory_clock_rate; }
|
385
|
+
/// Returns the maximum bus width between device and memory in bits. If
|
386
|
+
/// compiler does not support this API then returns default value 64 bits.
|
387
|
+
unsigned int get_memory_bus_width() const { return _memory_bus_width; }
|
388
|
+
uint32_t get_device_id() const { return _device_id; }
|
389
|
+
std::array<unsigned char, 16> get_uuid() const { return _uuid; }
|
390
|
+
/// Returns global memory cache size in bytes.
|
391
|
+
unsigned int get_global_mem_cache_size() const
|
392
|
+
{
|
393
|
+
return _global_mem_cache_size;
|
394
|
+
}
|
395
|
+
|
396
|
+
// set interface
|
397
|
+
void set_name(const char *name)
|
398
|
+
{
|
399
|
+
size_t length = strlen(name);
|
400
|
+
if (length < 256)
|
401
|
+
{
|
402
|
+
std::memcpy(_name, name, length + 1);
|
403
|
+
}
|
404
|
+
else
|
405
|
+
{
|
406
|
+
std::memcpy(_name, name, 255);
|
407
|
+
_name[255] = '\0';
|
408
|
+
}
|
409
|
+
}
|
410
|
+
void set_max_work_item_sizes(const sycl::range<3> max_work_item_sizes)
|
411
|
+
{
|
412
|
+
for (int i = 0; i < 3; ++i)
|
413
|
+
_max_work_item_sizes_i[i] = max_work_item_sizes[i];
|
414
|
+
}
|
415
|
+
[[deprecated]] void
|
416
|
+
set_max_work_item_sizes(const sycl::id<3> max_work_item_sizes)
|
417
|
+
{
|
418
|
+
for (int i = 0; i < 3; ++i)
|
419
|
+
{
|
420
|
+
_max_work_item_sizes_i[i] = max_work_item_sizes[i];
|
421
|
+
}
|
422
|
+
}
|
423
|
+
void set_host_unified_memory(bool host_unified_memory)
|
424
|
+
{
|
425
|
+
_host_unified_memory = host_unified_memory;
|
426
|
+
}
|
427
|
+
void set_major_version(int major) { _major = major; }
|
428
|
+
void set_minor_version(int minor) { _minor = minor; }
|
429
|
+
void set_integrated(int integrated) { _integrated = integrated; }
|
430
|
+
void set_max_clock_frequency(int frequency) { _frequency = frequency; }
|
431
|
+
void set_max_compute_units(int max_compute_units)
|
432
|
+
{
|
433
|
+
_max_compute_units = max_compute_units;
|
434
|
+
}
|
435
|
+
void set_global_mem_size(size_t global_mem_size)
|
436
|
+
{
|
437
|
+
_global_mem_size = global_mem_size;
|
438
|
+
}
|
439
|
+
void set_local_mem_size(size_t local_mem_size)
|
440
|
+
{
|
441
|
+
_local_mem_size = local_mem_size;
|
442
|
+
}
|
443
|
+
void set_max_mem_alloc_size(size_t max_mem_alloc_size)
|
444
|
+
{
|
445
|
+
_max_mem_alloc_size = max_mem_alloc_size;
|
446
|
+
}
|
447
|
+
void set_max_work_group_size(int max_work_group_size)
|
448
|
+
{
|
449
|
+
_max_work_group_size = max_work_group_size;
|
450
|
+
}
|
451
|
+
void set_max_sub_group_size(int max_sub_group_size)
|
452
|
+
{
|
453
|
+
_max_sub_group_size = max_sub_group_size;
|
454
|
+
}
|
455
|
+
void
|
456
|
+
set_max_work_items_per_compute_unit(int max_work_items_per_compute_unit)
|
457
|
+
{
|
458
|
+
_max_work_items_per_compute_unit = max_work_items_per_compute_unit;
|
459
|
+
}
|
460
|
+
void set_max_nd_range_size(int max_nd_range_size[])
|
461
|
+
{
|
462
|
+
for (int i = 0; i < 3; i++)
|
463
|
+
{
|
464
|
+
_max_nd_range_size[i] = max_nd_range_size[i];
|
465
|
+
_max_nd_range_size_i[i] = max_nd_range_size[i];
|
466
|
+
}
|
467
|
+
}
|
468
|
+
void set_memory_clock_rate(unsigned int memory_clock_rate)
|
469
|
+
{
|
470
|
+
_memory_clock_rate = memory_clock_rate;
|
471
|
+
}
|
472
|
+
void set_memory_bus_width(unsigned int memory_bus_width)
|
473
|
+
{
|
474
|
+
_memory_bus_width = memory_bus_width;
|
475
|
+
}
|
476
|
+
void
|
477
|
+
set_max_register_size_per_work_group(int max_register_size_per_work_group)
|
478
|
+
{
|
479
|
+
_max_register_size_per_work_group = max_register_size_per_work_group;
|
480
|
+
}
|
481
|
+
void set_device_id(uint32_t device_id)
|
482
|
+
{
|
483
|
+
_device_id = device_id;
|
484
|
+
}
|
485
|
+
void set_uuid(std::array<unsigned char, 16> uuid)
|
486
|
+
{
|
487
|
+
_uuid = std::move(uuid);
|
488
|
+
}
|
489
|
+
void set_global_mem_cache_size(unsigned int global_mem_cache_size)
|
490
|
+
{
|
491
|
+
_global_mem_cache_size = global_mem_cache_size;
|
492
|
+
}
|
493
|
+
|
494
|
+
private:
|
495
|
+
char _name[256];
|
496
|
+
int _max_work_item_sizes_i[3];
|
497
|
+
bool _host_unified_memory = false;
|
498
|
+
int _major;
|
499
|
+
int _minor;
|
500
|
+
int _integrated = 0;
|
501
|
+
int _frequency;
|
502
|
+
// Set estimated value 3200000 kHz as default value.
|
503
|
+
unsigned int _memory_clock_rate = 3200000;
|
504
|
+
// Set estimated value 64 bits as default value.
|
505
|
+
unsigned int _memory_bus_width = 64;
|
506
|
+
unsigned int _global_mem_cache_size;
|
507
|
+
int _max_compute_units;
|
508
|
+
int _max_work_group_size;
|
509
|
+
int _max_sub_group_size;
|
510
|
+
int _max_work_items_per_compute_unit;
|
511
|
+
int _max_register_size_per_work_group;
|
512
|
+
size_t _global_mem_size;
|
513
|
+
size_t _local_mem_size;
|
514
|
+
size_t _max_mem_alloc_size;
|
515
|
+
size_t _max_nd_range_size[3];
|
516
|
+
int _max_nd_range_size_i[3];
|
517
|
+
uint32_t _device_id;
|
518
|
+
std::array<unsigned char, 16> _uuid;
|
519
|
+
};
|
520
|
+
|
521
|
+
/// Returns the major version number that detail::get_version() extracts for
/// \p dev; the minor number is computed as well and discarded.
static int get_major_version(const sycl::device &dev)
{
    int major_part, minor_part;
    detail::get_version(dev, major_part, minor_part);
    return major_part;
}
|
527
|
+
|
528
|
+
/// Returns the minor version number that detail::get_version() extracts for
/// \p dev; the major number is computed as well and discarded.
static int get_minor_version(const sycl::device &dev)
{
    int major_part, minor_part;
    detail::get_version(dev, major_part, minor_part);
    return minor_part;
}
|
534
|
+
|
535
|
+
/// Queries the properties of \p dev and stores them in \p out.
///
/// The values are collected in a local device_info and assigned to \p out as
/// the last step, so \p out is only written once all queries have completed.
///
/// \param [out] out Receives the collected properties.
/// \param [in] dev Device to query.
static void get_device_info(device_info &out, const sycl::device &dev)
{
    device_info collected;
    collected.set_name(dev.get_info<sycl::info::device::name>().c_str());

    int ver_major, ver_minor;
    detail::get_version(dev, ver_major, ver_minor);
    collected.set_major_version(ver_major);
    collected.set_minor_version(ver_minor);

    collected.set_max_work_item_sizes(
#if (__SYCL_COMPILER_VERSION && __SYCL_COMPILER_VERSION < 20220902)
        // oneAPI DPC++ compiler older than 2022/09/02, where max_work_item_sizes
        // is an enum class element
        dev.get_info<sycl::info::device::max_work_item_sizes>());
#else
        // SYCL 2020-conformant code, max_work_item_sizes is a struct templated by
        // an int
        dev.get_info<sycl::info::device::max_work_item_sizes<3>>());
#endif
    collected.set_host_unified_memory(dev.has(sycl::aspect::usm_host_allocations));

    // The queried frequency is multiplied by 1000 before it is stored.
    collected.set_max_clock_frequency(
        dev.get_info<sycl::info::device::max_clock_frequency>() * 1000);

    collected.set_max_compute_units(
        dev.get_info<sycl::info::device::max_compute_units>());
    collected.set_max_work_group_size(
        dev.get_info<sycl::info::device::max_work_group_size>());
    collected.set_global_mem_size(dev.get_info<sycl::info::device::global_mem_size>());
    collected.set_local_mem_size(dev.get_info<sycl::info::device::local_mem_size>());
    collected.set_max_mem_alloc_size(dev.get_info<sycl::info::device::max_mem_alloc_size>());

#if (defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 6)
    if (dev.has(sycl::aspect::ext_intel_memory_clock_rate))
    {
        const unsigned int reported_rate =
            dev.get_info<sycl::ext::intel::info::device::memory_clock_rate>();
        // A reported rate of 0 keeps the default stored in device_info.
        if (reported_rate != 0)
            collected.set_memory_clock_rate(1000 * reported_rate);
    }
    if (dev.has(sycl::aspect::ext_intel_memory_bus_width))
    {
        collected.set_memory_bus_width(
            dev.get_info<sycl::ext::intel::info::device::memory_bus_width>());
    }
    if (dev.has(sycl::aspect::ext_intel_device_id))
    {
        collected.set_device_id(
            dev.get_info<sycl::ext::intel::info::device::device_id>());
    }
    if (dev.has(sycl::aspect::ext_intel_device_info_uuid))
    {
        collected.set_uuid(dev.get_info<sycl::ext::intel::info::device::uuid>());
    }
#elif defined(_MSC_VER) && !defined(__clang__)
#pragma message("get_device_info: querying memory_clock_rate and \
memory_bus_width are not supported by the compiler used. \
Use 3200000 kHz as memory_clock_rate default value. \
Use 64 bits as memory_bus_width default value.")
#else
#warning "get_device_info: querying memory_clock_rate and \
memory_bus_width are not supported by the compiler used. \
Use 3200000 kHz as memory_clock_rate default value. \
Use 64 bits as memory_bus_width default value."
#endif

    // Largest supported sub-group size; stays 1 when no reported size is larger.
    const std::vector<size_t> supported_sub_group_sizes =
        dev.get_info<sycl::info::device::sub_group_sizes>();
    size_t largest_sub_group = 1;
    for (const size_t candidate : supported_sub_group_sizes)
    {
        largest_sub_group = std::max(largest_sub_group, candidate);
    }
    collected.set_max_sub_group_size(largest_sub_group);

    // The work-group size limit is also stored as the per-compute-unit
    // work-item limit.
    collected.set_max_work_items_per_compute_unit(
        dev.get_info<sycl::info::device::max_work_group_size>());

    int nd_range_limits[] = {0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF};
    collected.set_max_nd_range_size(nd_range_limits);

    // Estimates max register size per work group, feel free to update the value
    // according to device properties.
    collected.set_max_register_size_per_work_group(65536);

    collected.set_global_mem_cache_size(
        dev.get_info<sycl::info::device::global_mem_cache_size>());

    out = collected;
}
|
627
|
+
|
628
|
+
/// dpct device extension
|
629
|
+
class device_ext : public sycl::device {
|
630
|
+
typedef std::mutex mutex_type;
|
631
|
+
|
632
|
+
public:
|
633
|
+
device_ext() : sycl::device() {}
|
634
|
+
~device_ext() {
|
635
|
+
std::lock_guard<mutex_type> lock(m_mutex);
|
636
|
+
clear_queues();
|
637
|
+
}
|
638
|
+
device_ext(const sycl::device &base) : sycl::device(base) {
|
639
|
+
std::lock_guard<mutex_type> lock(m_mutex);
|
640
|
+
init_queues();
|
641
|
+
}
|
642
|
+
|
643
|
+
int is_native_atomic_supported() { return 0; }
|
644
|
+
int get_major_version() const { return dpct::get_major_version(*this); }
|
645
|
+
|
646
|
+
int get_minor_version() const { return dpct::get_minor_version(*this); }
|
647
|
+
|
648
|
+
int get_max_compute_units() const {
|
649
|
+
return get_device_info().get_max_compute_units();
|
650
|
+
}
|
651
|
+
|
652
|
+
/// Return the maximum clock frequency of this device in KHz.
|
653
|
+
int get_max_clock_frequency() const {
|
654
|
+
return get_device_info().get_max_clock_frequency();
|
655
|
+
}
|
656
|
+
|
657
|
+
int get_integrated() const { return get_device_info().get_integrated(); }
|
658
|
+
|
659
|
+
int get_max_sub_group_size() const {
|
660
|
+
return get_device_info().get_max_sub_group_size();
|
661
|
+
}
|
662
|
+
|
663
|
+
int get_max_register_size_per_work_group() const {
|
664
|
+
return get_device_info().get_max_register_size_per_work_group();
|
665
|
+
}
|
666
|
+
|
667
|
+
int get_max_work_group_size() const {
|
668
|
+
return get_device_info().get_max_work_group_size();
|
669
|
+
}
|
670
|
+
|
671
|
+
int get_mem_base_addr_align() const {
|
672
|
+
return get_info<sycl::info::device::mem_base_addr_align>();
|
673
|
+
}
|
674
|
+
|
675
|
+
size_t get_global_mem_size() const {
|
676
|
+
return get_device_info().get_global_mem_size();
|
677
|
+
}
|
678
|
+
|
679
|
+
size_t get_max_mem_alloc_size() const {
|
680
|
+
return get_device_info().get_max_mem_alloc_size();
|
681
|
+
}
|
682
|
+
|
683
|
+
/// Get the number of bytes of free and total memory on the SYCL device.
|
684
|
+
/// \param [out] free_memory The number of bytes of free memory on the
|
685
|
+
/// SYCL device. \param [out] total_memory The number of bytes of total
|
686
|
+
/// memory on the SYCL device.
|
687
|
+
void get_memory_info(size_t &free_memory, size_t &total_memory) {
|
688
|
+
total_memory = get_device_info().get_global_mem_size();
|
689
|
+
const char *warning_info =
|
690
|
+
"get_memory_info: [warning] ext_intel_free_memory is not "
|
691
|
+
"supported (export/set ZES_ENABLE_SYSMAN=1 to support), "
|
692
|
+
"use total memory as free memory";
|
693
|
+
#if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105)
|
694
|
+
if (!has(sycl::aspect::ext_intel_free_memory)) {
|
695
|
+
std::cerr << warning_info << std::endl;
|
696
|
+
free_memory = total_memory;
|
697
|
+
} else {
|
698
|
+
free_memory = get_info<sycl::ext::intel::info::device::free_memory>();
|
699
|
+
}
|
700
|
+
#else
|
701
|
+
std::cerr << warning_info << std::endl;
|
702
|
+
free_memory = total_memory;
|
703
|
+
#if defined(_MSC_VER) && !defined(__clang__)
|
704
|
+
#pragma message("Querying the number of bytes of free memory is not supported")
|
705
|
+
#else
|
706
|
+
#warning "Querying the number of bytes of free memory is not supported"
|
707
|
+
#endif
|
708
|
+
#endif
|
709
|
+
}
|
710
|
+
|
711
|
+
void get_device_info(device_info &out) const {
|
712
|
+
dpct::get_device_info(out, *this);
|
713
|
+
}
|
714
|
+
|
715
|
+
device_info get_device_info() const {
|
716
|
+
device_info prop;
|
717
|
+
dpct::get_device_info(prop, *this);
|
718
|
+
return prop;
|
719
|
+
}
|
720
|
+
|
721
|
+
void reset() {
|
722
|
+
std::lock_guard<mutex_type> lock(m_mutex);
|
723
|
+
clear_queues();
|
724
|
+
init_queues();
|
725
|
+
}
|
726
|
+
|
727
|
+
sycl::queue &in_order_queue() { return _q_in_order; }
|
728
|
+
|
729
|
+
sycl::queue &out_of_order_queue() { return _q_out_of_order; }
|
730
|
+
|
731
|
+
sycl::queue &default_queue() { return in_order_queue(); }
|
732
|
+
|
733
|
+
void queues_wait_and_throw() {
|
734
|
+
std::unique_lock<mutex_type> lock(m_mutex);
|
735
|
+
lock.unlock();
|
736
|
+
for (auto &q : _queues) {
|
737
|
+
q.wait_and_throw();
|
738
|
+
}
|
739
|
+
// Guard the destruct of current_queues to make sure the ref count is
|
740
|
+
// safe.
|
741
|
+
lock.lock();
|
742
|
+
}
|
743
|
+
|
744
|
+
sycl::queue create_queue(bool enable_exception_handler = false) {
|
745
|
+
return create_in_order_queue(enable_exception_handler);
|
746
|
+
}
|
747
|
+
|
748
|
+
sycl::queue create_queue(sycl::device device,
|
749
|
+
bool enable_exception_handler = false) {
|
750
|
+
return create_in_order_queue(device, enable_exception_handler);
|
751
|
+
}
|
752
|
+
|
753
|
+
sycl::queue create_in_order_queue(bool enable_exception_handler = false) {
|
754
|
+
std::lock_guard<mutex_type> lock(m_mutex);
|
755
|
+
return create_queue_impl(enable_exception_handler,
|
756
|
+
sycl::property::queue::in_order());
|
757
|
+
}
|
758
|
+
|
759
|
+
sycl::queue create_in_order_queue(sycl::device device,
|
760
|
+
bool enable_exception_handler = false) {
|
761
|
+
std::lock_guard<mutex_type> lock(m_mutex);
|
762
|
+
return create_queue_impl(device, enable_exception_handler,
|
763
|
+
sycl::property::queue::in_order());
|
764
|
+
}
|
765
|
+
|
766
|
+
sycl::queue create_out_of_order_queue(
|
767
|
+
bool enable_exception_handler = false) {
|
768
|
+
std::lock_guard<mutex_type> lock(m_mutex);
|
769
|
+
return create_queue_impl(enable_exception_handler);
|
770
|
+
}
|
771
|
+
|
772
|
+
void destroy_queue(sycl::queue queue) {
|
773
|
+
std::lock_guard<mutex_type> lock(m_mutex);
|
774
|
+
_queues.erase(std::remove_if(_queues.begin(), _queues.end(),
|
775
|
+
[=](const sycl::queue &q) -> bool
|
776
|
+
{
|
777
|
+
return q == queue;
|
778
|
+
}),
|
779
|
+
_queues.end());
|
780
|
+
}
|
781
|
+
void set_saved_queue(sycl::queue q) {
|
782
|
+
std::lock_guard<mutex_type> lock(m_mutex);
|
783
|
+
_saved_queue = q;
|
784
|
+
}
|
785
|
+
sycl::queue get_saved_queue() const {
|
786
|
+
std::lock_guard<mutex_type> lock(m_mutex);
|
787
|
+
return _saved_queue;
|
788
|
+
}
|
789
|
+
|
790
|
+
private:
|
791
|
+
void clear_queues() { _queues.clear(); }
|
792
|
+
|
793
|
+
void init_queues() {
|
794
|
+
_q_in_order =
|
795
|
+
create_queue_impl(true, sycl::property::queue::in_order());
|
796
|
+
_q_out_of_order = create_queue_impl(true);
|
797
|
+
_saved_queue = default_queue();
|
798
|
+
}
|
799
|
+
|
800
|
+
/// Caller should acquire resource \p m_mutex before calling this
|
801
|
+
/// function.
|
802
|
+
template <class... Properties>
|
803
|
+
sycl::queue create_queue_impl(bool enable_exception_handler,
|
804
|
+
Properties... properties) {
|
805
|
+
sycl::async_handler eh = {};
|
806
|
+
if (enable_exception_handler) {
|
807
|
+
eh = exception_handler;
|
808
|
+
}
|
809
|
+
_queues.push_back(sycl::queue(
|
810
|
+
*this, eh,
|
811
|
+
sycl::property_list(
|
812
|
+
#ifdef DPCT_PROFILING_ENABLED
|
813
|
+
sycl::property::queue::enable_profiling(),
|
814
|
+
#endif
|
815
|
+
properties...)));
|
816
|
+
|
817
|
+
return _queues.back();
|
818
|
+
}
|
819
|
+
|
820
|
+
template <class... Properties>
|
821
|
+
sycl::queue create_queue_impl(sycl::device device,
|
822
|
+
bool enable_exception_handler,
|
823
|
+
Properties... properties) {
|
824
|
+
sycl::async_handler eh = {};
|
825
|
+
if (enable_exception_handler) {
|
826
|
+
eh = exception_handler;
|
827
|
+
}
|
828
|
+
_queues.push_back(sycl::queue(
|
829
|
+
device, eh,
|
830
|
+
sycl::property_list(
|
831
|
+
#ifdef DPCT_PROFILING_ENABLED
|
832
|
+
sycl::property::queue::enable_profiling(),
|
833
|
+
#endif
|
834
|
+
properties...)));
|
835
|
+
|
836
|
+
return _queues.back();
|
837
|
+
}
|
838
|
+
|
839
|
+
void get_version(int &major, int &minor) const {
|
840
|
+
detail::get_version(*this, major, minor);
|
841
|
+
}
|
842
|
+
sycl::queue _q_in_order, _q_out_of_order;
|
843
|
+
sycl::queue _saved_queue;
|
844
|
+
std::vector<sycl::queue> _queues;
|
845
|
+
mutable mutex_type m_mutex;
|
846
|
+
};
|
847
|
+
|
848
|
+
|
849
|
+
/// device manager
|
850
|
+
class dev_mgr
|
851
|
+
{
|
852
|
+
public:
|
853
|
+
device_ext ¤t_device()
|
854
|
+
{
|
855
|
+
unsigned int dev_id = current_device_id();
|
856
|
+
check_id(dev_id);
|
857
|
+
return *_devs[dev_id];
|
858
|
+
}
|
859
|
+
device_ext &cpu_device() const
|
860
|
+
{
|
861
|
+
std::lock_guard<std::recursive_mutex> lock(m_mutex);
|
862
|
+
if (_cpu_device == -1)
|
863
|
+
{
|
864
|
+
throw std::runtime_error("no valid cpu device");
|
865
|
+
}
|
866
|
+
else
|
867
|
+
{
|
868
|
+
return *_devs[_cpu_device];
|
869
|
+
}
|
870
|
+
}
|
871
|
+
device_ext &get_device(unsigned int id) const
|
872
|
+
{
|
873
|
+
std::lock_guard<std::recursive_mutex> lock(m_mutex);
|
874
|
+
check_id(id);
|
875
|
+
return *_devs[id];
|
876
|
+
}
|
877
|
+
unsigned int current_device_id() const
|
878
|
+
{
|
879
|
+
std::lock_guard<std::recursive_mutex> lock(m_mutex);
|
880
|
+
auto it = _thread2dev_map.find(get_tid());
|
881
|
+
if (it != _thread2dev_map.end())
|
882
|
+
return it->second;
|
883
|
+
return DEFAULT_DEVICE_ID;
|
884
|
+
}
|
885
|
+
|
886
|
+
/// Select device with a device ID.
|
887
|
+
/// \param [in] id The id of the device which can
|
888
|
+
/// be obtained through get_device_id(const sycl::device).
|
889
|
+
void select_device(unsigned int id)
|
890
|
+
{
|
891
|
+
std::lock_guard<std::recursive_mutex> lock(m_mutex);
|
892
|
+
check_id(id);
|
893
|
+
_thread2dev_map[get_tid()] = id;
|
894
|
+
}
|
895
|
+
unsigned int device_count() { return _devs.size(); }
|
896
|
+
|
897
|
+
unsigned int get_device_id(const sycl::device &dev)
|
898
|
+
{
|
899
|
+
unsigned int id = 0;
|
900
|
+
for (auto &dev_item : _devs)
|
901
|
+
{
|
902
|
+
if (*dev_item == dev)
|
903
|
+
{
|
904
|
+
return id;
|
905
|
+
}
|
906
|
+
id++;
|
907
|
+
}
|
908
|
+
return -1;
|
909
|
+
}
|
910
|
+
|
911
|
+
inline std::string get_preferred_gpu_platform_name() {
|
912
|
+
std::string result;
|
913
|
+
|
914
|
+
std::string filter = "";
|
915
|
+
char* env = getenv("ONEAPI_DEVICE_SELECTOR");
|
916
|
+
if (env) {
|
917
|
+
if (std::strstr(env, "level_zero")) {
|
918
|
+
filter = "level-zero";
|
919
|
+
}
|
920
|
+
else if (std::strstr(env, "opencl")) {
|
921
|
+
filter = "opencl";
|
922
|
+
}
|
923
|
+
else if (std::strstr(env, "cuda")) {
|
924
|
+
filter = "cuda";
|
925
|
+
}
|
926
|
+
else if (std::strstr(env, "hip")) {
|
927
|
+
filter = "hip";
|
928
|
+
}
|
929
|
+
else {
|
930
|
+
throw std::runtime_error("invalid device filter: " + std::string(env));
|
931
|
+
}
|
932
|
+
} else {
|
933
|
+
auto default_device = sycl::device(sycl::default_selector_v);
|
934
|
+
auto default_platform_name = default_device.get_platform().get_info<sycl::info::platform::name>();
|
935
|
+
|
936
|
+
if (std::strstr(default_platform_name.c_str(), "Level-Zero") || default_device.is_cpu()) {
|
937
|
+
filter = "level-zero";
|
938
|
+
}
|
939
|
+
else if (std::strstr(default_platform_name.c_str(), "CUDA")) {
|
940
|
+
filter = "cuda";
|
941
|
+
}
|
942
|
+
else if (std::strstr(default_platform_name.c_str(), "HIP")) {
|
943
|
+
filter = "hip";
|
944
|
+
}
|
945
|
+
}
|
946
|
+
|
947
|
+
auto platform_list = sycl::platform::get_platforms();
|
948
|
+
|
949
|
+
for (const auto& platform : platform_list) {
|
950
|
+
auto devices = platform.get_devices();
|
951
|
+
auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) {
|
952
|
+
return d.is_gpu();
|
953
|
+
});
|
954
|
+
|
955
|
+
if (gpu_dev == devices.end()) {
|
956
|
+
// cout << "platform [" << platform_name
|
957
|
+
// << "] does not contain GPU devices, skipping\n";
|
958
|
+
continue;
|
959
|
+
}
|
960
|
+
|
961
|
+
auto platform_name = platform.get_info<sycl::info::platform::name>();
|
962
|
+
std::string platform_name_low_case;
|
963
|
+
platform_name_low_case.resize(platform_name.size());
|
964
|
+
|
965
|
+
std::transform(
|
966
|
+
platform_name.begin(), platform_name.end(), platform_name_low_case.begin(), ::tolower);
|
967
|
+
|
968
|
+
if (platform_name_low_case.find(filter) == std::string::npos) {
|
969
|
+
// cout << "platform [" << platform_name
|
970
|
+
// << "] does not match with requested "
|
971
|
+
// << filter << ", skipping\n";
|
972
|
+
continue;
|
973
|
+
}
|
974
|
+
|
975
|
+
result = platform_name;
|
976
|
+
}
|
977
|
+
|
978
|
+
if (result.empty())
|
979
|
+
throw std::runtime_error("can not find preferred GPU platform");
|
980
|
+
|
981
|
+
return result;
|
982
|
+
}
|
983
|
+
|
984
|
+
template <class DeviceSelector>
|
985
|
+
std::enable_if_t<
|
986
|
+
std::is_invocable_r_v<int, DeviceSelector, const sycl::device &>>
|
987
|
+
select_device(const DeviceSelector &selector = sycl::gpu_selector_v)
|
988
|
+
{
|
989
|
+
sycl::device selected_device = sycl::device(selector);
|
990
|
+
unsigned int selected_device_id = get_device_id(selected_device);
|
991
|
+
select_device(selected_device_id);
|
992
|
+
}
|
993
|
+
|
994
|
+
/// Returns the instance of device manager singleton.
|
995
|
+
static dev_mgr &instance()
|
996
|
+
{
|
997
|
+
static dev_mgr d_m;
|
998
|
+
return d_m;
|
999
|
+
}
|
1000
|
+
dev_mgr(const dev_mgr &) = delete;
|
1001
|
+
dev_mgr &operator=(const dev_mgr &) = delete;
|
1002
|
+
dev_mgr(dev_mgr &&) = delete;
|
1003
|
+
dev_mgr &operator=(dev_mgr &&) = delete;
|
1004
|
+
|
1005
|
+
private:
|
1006
|
+
mutable std::recursive_mutex m_mutex;
|
1007
|
+
static bool compare_dev(sycl::device &device1, sycl::device &device2)
|
1008
|
+
{
|
1009
|
+
sycl::backend backend1 = device1.get_backend();
|
1010
|
+
sycl::backend backend2 = device2.get_backend();
|
1011
|
+
// levelzero backends always come first
|
1012
|
+
if(backend1 == sycl::backend::ext_oneapi_level_zero && backend2 != sycl::backend::ext_oneapi_level_zero) return true;
|
1013
|
+
if(backend1 != sycl::backend::ext_oneapi_level_zero && backend2 == sycl::backend::ext_oneapi_level_zero) return false;
|
1014
|
+
dpct::device_info prop1;
|
1015
|
+
dpct::get_device_info(prop1, device1);
|
1016
|
+
dpct::device_info prop2;
|
1017
|
+
dpct::get_device_info(prop2, device2);
|
1018
|
+
return prop1.get_max_compute_units() > prop2.get_max_compute_units();
|
1019
|
+
}
|
1020
|
+
static int convert_backend_index(std::string & backend) {
|
1021
|
+
if (backend == "ext_oneapi_level_zero:gpu") return 0;
|
1022
|
+
if (backend == "opencl:gpu") return 1;
|
1023
|
+
if (backend == "ext_oneapi_cuda:gpu") return 2;
|
1024
|
+
if (backend == "ext_oneapi_hip:gpu") return 3;
|
1025
|
+
if (backend == "opencl:cpu") return 4;
|
1026
|
+
if (backend == "opencl:acc") return 5;
|
1027
|
+
printf("convert_backend_index: can't handle backend=%s\n", backend.c_str());
|
1028
|
+
GGML_ABORT("fatal error");
|
1029
|
+
}
|
1030
|
+
static bool compare_backend(std::string &backend1, std::string &backend2) {
|
1031
|
+
return convert_backend_index(backend1) < convert_backend_index(backend2);
|
1032
|
+
}
|
1033
|
+
dev_mgr()
|
1034
|
+
{
|
1035
|
+
sycl::device default_device =
|
1036
|
+
sycl::device(sycl::default_selector_v);
|
1037
|
+
_devs.push_back(std::make_shared<device_ext>(default_device));
|
1038
|
+
|
1039
|
+
std::vector<sycl::device> sycl_all_devs;
|
1040
|
+
// Collect other devices except for the default device.
|
1041
|
+
if (default_device.is_cpu())
|
1042
|
+
_cpu_device = 0;
|
1043
|
+
|
1044
|
+
auto Platforms = sycl::platform::get_platforms();
|
1045
|
+
// Keep track of the number of devices per backend
|
1046
|
+
std::map<sycl::backend, size_t> DeviceNums;
|
1047
|
+
std::map<std::string, std::vector<sycl::device>> backend_devices;
|
1048
|
+
auto preferred_platform_name = get_preferred_gpu_platform_name();
|
1049
|
+
|
1050
|
+
while (!Platforms.empty()) {
|
1051
|
+
auto Platform = Platforms.back();
|
1052
|
+
Platforms.pop_back();
|
1053
|
+
auto platform_name = Platform.get_info<sycl::info::platform::name>();
|
1054
|
+
if (platform_name.compare(preferred_platform_name) != 0) {
|
1055
|
+
continue;
|
1056
|
+
}
|
1057
|
+
auto devices = Platform.get_devices();
|
1058
|
+
std::string backend_type = get_device_backend_and_type(devices[0]);
|
1059
|
+
for (const auto &device : devices) {
|
1060
|
+
backend_devices[backend_type].push_back(device);
|
1061
|
+
}
|
1062
|
+
}
|
1063
|
+
|
1064
|
+
std::vector<std::string> keys;
|
1065
|
+
for(auto it = backend_devices.begin(); it != backend_devices.end(); ++it) {
|
1066
|
+
keys.push_back(it->first);
|
1067
|
+
}
|
1068
|
+
std::sort(keys.begin(), keys.end(), compare_backend);
|
1069
|
+
|
1070
|
+
for (auto &key : keys) {
|
1071
|
+
std::vector<sycl::device> devs = backend_devices[key];
|
1072
|
+
std::sort(devs.begin(), devs.end(), compare_dev);
|
1073
|
+
for (const auto &dev : devs) {
|
1074
|
+
sycl_all_devs.push_back(dev);
|
1075
|
+
}
|
1076
|
+
}
|
1077
|
+
|
1078
|
+
for (auto &dev : sycl_all_devs)
|
1079
|
+
{
|
1080
|
+
if (dev == default_device)
|
1081
|
+
{
|
1082
|
+
continue;
|
1083
|
+
}
|
1084
|
+
_devs.push_back(std::make_shared<device_ext>(dev));
|
1085
|
+
if (_cpu_device == -1 && dev.is_cpu())
|
1086
|
+
{
|
1087
|
+
_cpu_device = _devs.size() - 1;
|
1088
|
+
}
|
1089
|
+
}
|
1090
|
+
}
|
1091
|
+
void check_id(unsigned int id) const
|
1092
|
+
{
|
1093
|
+
if (id >= _devs.size())
|
1094
|
+
{
|
1095
|
+
throw std::runtime_error("invalid device id");
|
1096
|
+
}
|
1097
|
+
}
|
1098
|
+
std::vector<std::shared_ptr<device_ext>> _devs;
|
1099
|
+
/// DEFAULT_DEVICE_ID is used, if current_device_id() can not find current
|
1100
|
+
/// thread id in _thread2dev_map, which means default device should be used
|
1101
|
+
/// for the current thread.
|
1102
|
+
const unsigned int DEFAULT_DEVICE_ID = 0;
|
1103
|
+
/// thread-id to device-id map.
|
1104
|
+
std::map<unsigned int, unsigned int> _thread2dev_map;
|
1105
|
+
int _cpu_device = -1;
|
1106
|
+
};
|
1107
|
+
|
1108
|
+
/// Returns the default queue of the device that the device manager reports
/// as current for the calling thread.
static inline sycl::queue &get_default_queue()
{
    device_ext &active_device = dev_mgr::instance().current_device();
    return active_device.default_queue();
}
|
1112
|
+
|
1113
|
+
namespace detail
|
1114
|
+
{
|
1115
|
+
/// Classification of a pointer by where it can be accessed.
enum class pointer_access_attribute
{
    host_only = 0,   ///< accessible from the host only
    device_only = 1, ///< accessible from the device only
    host_device = 2, ///< accessible from both host and device
    end = 3          ///< presumably an end marker, not a real category (TODO confirm)
};
|
1122
|
+
|
1123
|
+
/// Classifies \p ptr according to the USM allocation kind that
/// sycl::get_pointer_type() reports for it in the context of \p q.
///
/// Mapping: unknown -> host_only, device -> device_only,
/// shared and host -> host_device.
///
/// \param [in] q Queue whose context is used for the lookup.
/// \param [in] ptr Pointer to classify.
/// \throws std::runtime_error if the reported allocation kind is none of the
/// four handled enumerators (previously control fell off the end of the
/// function in that case).
static pointer_access_attribute get_pointer_attribute(sycl::queue &q,
                                                      const void *ptr)
{
    switch (sycl::get_pointer_type(ptr, q.get_context()))
    {
    case sycl::usm::alloc::unknown:
        return pointer_access_attribute::host_only;
    case sycl::usm::alloc::device:
        return pointer_access_attribute::device_only;
    case sycl::usm::alloc::shared:
    case sycl::usm::alloc::host:
        return pointer_access_attribute::host_device;
    }
    // Only reached for a value outside the enumerators handled above.
    throw std::runtime_error(
        "get_pointer_attribute: unexpected USM allocation type");
}
|
1137
|
+
|
1138
|
+
template <typename ArgT>
|
1139
|
+
inline constexpr std::uint64_t get_type_combination_id(ArgT Val)
|
1140
|
+
{
|
1141
|
+
static_assert((unsigned char)library_data_t::library_data_t_size <=
|
1142
|
+
std::numeric_limits<unsigned char>::max() &&
|
1143
|
+
"library_data_t size exceeds limit.");
|
1144
|
+
static_assert(std::is_same_v<ArgT, library_data_t>, "Unsupported ArgT");
|
1145
|
+
return (std::uint64_t)Val;
|
1146
|
+
}
|
1147
|
+
|
1148
|
+
template <typename FirstT, typename... RestT>
|
1149
|
+
inline constexpr std::uint64_t get_type_combination_id(FirstT FirstVal,
|
1150
|
+
RestT... RestVal)
|
1151
|
+
{
|
1152
|
+
static_assert((std::uint8_t)library_data_t::library_data_t_size <=
|
1153
|
+
std::numeric_limits<unsigned char>::max() &&
|
1154
|
+
"library_data_t size exceeds limit.");
|
1155
|
+
static_assert(sizeof...(RestT) <= 8 && "Too many parameters");
|
1156
|
+
static_assert(std::is_same_v<FirstT, library_data_t>, "Unsupported FirstT");
|
1157
|
+
return get_type_combination_id(RestVal...) << 8 | ((std::uint64_t)FirstVal);
|
1158
|
+
}
|
1159
|
+
|
1160
|
+
class mem_mgr
|
1161
|
+
{
|
1162
|
+
mem_mgr()
|
1163
|
+
{
|
1164
|
+
// Reserved address space, no real memory allocation happens here.
|
1165
|
+
#if defined(__linux__)
|
1166
|
+
mapped_address_space =
|
1167
|
+
(byte_t *)mmap(nullptr, mapped_region_size, PROT_NONE,
|
1168
|
+
MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
|
1169
|
+
#elif defined(_WIN64)
|
1170
|
+
mapped_address_space = (byte_t *)VirtualAlloc(
|
1171
|
+
NULL, // NULL specified as the base address parameter
|
1172
|
+
mapped_region_size, // Size of allocation
|
1173
|
+
MEM_RESERVE, // Allocate reserved pages
|
1174
|
+
PAGE_NOACCESS); // Protection = no access
|
1175
|
+
#else
|
1176
|
+
#error "Only support Windows and Linux."
|
1177
|
+
#endif
|
1178
|
+
next_free = mapped_address_space;
|
1179
|
+
}
|
1180
|
+
|
1181
|
+
public:
|
1182
|
+
using buffer_id_t = int;
|
1183
|
+
|
1184
|
+
struct allocation
|
1185
|
+
{
|
1186
|
+
buffer_t buffer;
|
1187
|
+
byte_t *alloc_ptr;
|
1188
|
+
size_t size;
|
1189
|
+
};
|
1190
|
+
|
1191
|
+
~mem_mgr()
|
1192
|
+
{
|
1193
|
+
#if defined(__linux__)
|
1194
|
+
munmap(mapped_address_space, mapped_region_size);
|
1195
|
+
#elif defined(_WIN64)
|
1196
|
+
VirtualFree(mapped_address_space, 0, MEM_RELEASE);
|
1197
|
+
#else
|
1198
|
+
#error "Only support Windows and Linux."
|
1199
|
+
#endif
|
1200
|
+
}
|
1201
|
+
|
1202
|
+
mem_mgr(const mem_mgr &) = delete;
|
1203
|
+
mem_mgr &operator=(const mem_mgr &) = delete;
|
1204
|
+
mem_mgr(mem_mgr &&) = delete;
|
1205
|
+
mem_mgr &operator=(mem_mgr &&) = delete;
|
1206
|
+
|
1207
|
+
/// Allocate
|
1208
|
+
void *mem_alloc(size_t size)
|
1209
|
+
{
|
1210
|
+
if (!size)
|
1211
|
+
return nullptr;
|
1212
|
+
std::lock_guard<std::mutex> lock(m_mutex);
|
1213
|
+
if (next_free + size > mapped_address_space + mapped_region_size)
|
1214
|
+
{
|
1215
|
+
throw std::runtime_error("dpct_malloc: out of memory for virtual memory pool");
|
1216
|
+
}
|
1217
|
+
// Allocation
|
1218
|
+
sycl::range<1> r(size);
|
1219
|
+
buffer_t buf(r);
|
1220
|
+
allocation A{buf, next_free, size};
|
1221
|
+
// Map allocation to device pointer
|
1222
|
+
void *result = next_free;
|
1223
|
+
m_map.emplace(next_free + size, A);
|
1224
|
+
// Update pointer to the next free space.
|
1225
|
+
next_free += (size + extra_padding + alignment - 1) & ~(alignment - 1);
|
1226
|
+
|
1227
|
+
return result;
|
1228
|
+
}
|
1229
|
+
|
1230
|
+
/// Deallocate
|
1231
|
+
void mem_free(const void *ptr)
|
1232
|
+
{
|
1233
|
+
if (!ptr)
|
1234
|
+
return;
|
1235
|
+
std::lock_guard<std::mutex> lock(m_mutex);
|
1236
|
+
auto it = get_map_iterator(ptr);
|
1237
|
+
m_map.erase(it);
|
1238
|
+
}
|
1239
|
+
|
1240
|
+
/// map: device pointer -> allocation(buffer, alloc_ptr, size)
|
1241
|
+
allocation translate_ptr(const void *ptr)
|
1242
|
+
{
|
1243
|
+
std::lock_guard<std::mutex> lock(m_mutex);
|
1244
|
+
auto it = get_map_iterator(ptr);
|
1245
|
+
return it->second;
|
1246
|
+
}
|
1247
|
+
|
1248
|
+
/// Check if the pointer represents device pointer or not.
|
1249
|
+
bool is_device_ptr(const void *ptr) const
|
1250
|
+
{
|
1251
|
+
std::lock_guard<std::mutex> lock(m_mutex);
|
1252
|
+
return (mapped_address_space <= ptr) &&
|
1253
|
+
(ptr < mapped_address_space + mapped_region_size);
|
1254
|
+
}
|
1255
|
+
|
1256
|
+
/// Returns the instance of memory manager singleton.
|
1257
|
+
static mem_mgr &instance()
|
1258
|
+
{
|
1259
|
+
static mem_mgr m;
|
1260
|
+
return m;
|
1261
|
+
}
|
1262
|
+
|
1263
|
+
private:
|
1264
|
+
std::map<byte_t *, allocation> m_map;
|
1265
|
+
mutable std::mutex m_mutex;
|
1266
|
+
byte_t *mapped_address_space;
|
1267
|
+
byte_t *next_free;
|
1268
|
+
const size_t mapped_region_size = 128ull * 1024 * 1024 * 1024;
|
1269
|
+
const size_t alignment = 256;
|
1270
|
+
/// This padding may be defined to some positive value to debug
|
1271
|
+
/// out of bound accesses.
|
1272
|
+
const size_t extra_padding = 0;
|
1273
|
+
|
1274
|
+
std::map<byte_t *, allocation>::iterator get_map_iterator(const void *ptr)
|
1275
|
+
{
|
1276
|
+
auto it = m_map.upper_bound(const_cast<byte_t *>(reinterpret_cast<const byte_t *>(ptr)));
|
1277
|
+
if (it == m_map.end())
|
1278
|
+
{
|
1279
|
+
// Not a virtual pointer.
|
1280
|
+
throw std::runtime_error("can not get buffer from non-virtual pointer");
|
1281
|
+
}
|
1282
|
+
const allocation &alloc = it->second;
|
1283
|
+
if (ptr < alloc.alloc_ptr)
|
1284
|
+
{
|
1285
|
+
// Out of bound.
|
1286
|
+
// This may happen if there's a gap between allocations due to alignment
|
1287
|
+
// or extra padding and pointer points to this gap.
|
1288
|
+
throw std::runtime_error("invalid virtual pointer");
|
1289
|
+
}
|
1290
|
+
return it;
|
1291
|
+
}
|
1292
|
+
};
|
1293
|
+
|
1294
|
+
template <class T, memory_region Memory, size_t Dimension>
|
1295
|
+
class accessor;
|
1296
|
+
template <memory_region Memory, class T = byte_t>
|
1297
|
+
class memory_traits
|
1298
|
+
{
|
1299
|
+
public:
|
1300
|
+
static constexpr sycl::access::target target =
|
1301
|
+
sycl::access::target::device;
|
1302
|
+
static constexpr sycl::access_mode mode =
|
1303
|
+
(Memory == constant) ? sycl::access_mode::read
|
1304
|
+
: sycl::access_mode::read_write;
|
1305
|
+
static constexpr size_t type_size = sizeof(T);
|
1306
|
+
using element_t =
|
1307
|
+
typename std::conditional<Memory == constant, const T, T>::type;
|
1308
|
+
using value_t = typename std::remove_cv<T>::type;
|
1309
|
+
template <size_t Dimension = 1>
|
1310
|
+
using accessor_t = typename std::conditional<
|
1311
|
+
Memory == local, sycl::local_accessor<value_t, Dimension>,
|
1312
|
+
sycl::accessor<T, Dimension, mode, target>>::type;
|
1313
|
+
using pointer_t = T *;
|
1314
|
+
};
|
1315
|
+
|
1316
|
+
static inline void *dpct_malloc(size_t size, sycl::queue &q)
|
1317
|
+
{
|
1318
|
+
return sycl::malloc_device(size, q.get_device(), q.get_context());
|
1319
|
+
}
|
1320
|
+
|
1321
|
+
#define PITCH_DEFAULT_ALIGN(x) (((x) + 31) & ~(0x1F))
|
1322
|
+
static inline void *dpct_malloc(size_t &pitch, size_t x, size_t y, size_t z,
|
1323
|
+
sycl::queue &q)
|
1324
|
+
{
|
1325
|
+
pitch = PITCH_DEFAULT_ALIGN(x);
|
1326
|
+
return dpct_malloc(pitch * y * z, q);
|
1327
|
+
}
|
1328
|
+
|
1329
|
+
/**
|
1330
|
+
* @brief Sets \p value to the first \p size elements starting from \p dev_ptr in \p q.
|
1331
|
+
* @tparam valueT The type of the element to be set.
|
1332
|
+
* @param [in] q The queue in which the operation is done.
|
1333
|
+
* @param [in] dev_ptr Pointer to the virtual device memory address.
|
1334
|
+
* @param [in] value The value to be set.
|
1335
|
+
* @param [in] size Number of elements to be set to the value.
|
1336
|
+
* @return An event representing the memset operation.
|
1337
|
+
*/
|
1338
|
+
template <typename valueT>
|
1339
|
+
static inline sycl::event dpct_memset(sycl::queue &q, void *dev_ptr,
|
1340
|
+
valueT value, size_t size)
|
1341
|
+
{
|
1342
|
+
return q.fill(dev_ptr, value, size);
|
1343
|
+
}
|
1344
|
+
|
1345
|
+
/**
|
1346
|
+
* @brief Sets \p value to the 3D memory region pointed by \p data in \p q.
|
1347
|
+
* @tparam valueT The type of the element to be set.
|
1348
|
+
* @param [in] q The queue in which the operation is done.
|
1349
|
+
* @param [in] data Pointer to the pitched device memory region.
|
1350
|
+
* @param [in] value The value to be set.
|
1351
|
+
* @param [in] size 3D memory region by number of elements.
|
1352
|
+
* @return An event list representing the memset operations.
|
1353
|
+
*/
|
1354
|
+
template <typename valueT>
|
1355
|
+
static inline std::vector<sycl::event>
|
1356
|
+
dpct_memset(sycl::queue &q, pitched_data data, valueT value,
|
1357
|
+
sycl::range<3> size)
|
1358
|
+
{
|
1359
|
+
std::vector<sycl::event> event_list;
|
1360
|
+
size_t slice = data.get_pitch() * data.get_y();
|
1361
|
+
unsigned char *data_surface = (unsigned char *)data.get_data_ptr();
|
1362
|
+
for (size_t z = 0; z < size.get(2); ++z)
|
1363
|
+
{
|
1364
|
+
unsigned char *data_ptr = data_surface;
|
1365
|
+
for (size_t y = 0; y < size.get(1); ++y)
|
1366
|
+
{
|
1367
|
+
event_list.push_back(dpct_memset(q, data_ptr, value, size.get(0)));
|
1368
|
+
data_ptr += data.get_pitch();
|
1369
|
+
}
|
1370
|
+
data_surface += slice;
|
1371
|
+
}
|
1372
|
+
return event_list;
|
1373
|
+
}
|
1374
|
+
|
1375
|
+
/**
|
1376
|
+
* @brief Sets \p val to the pitched 2D memory region pointed by \p ptr in \p q.
|
1377
|
+
* @tparam valueT The type of the element to be set.
|
1378
|
+
* @param [in] q The queue in which the operation is done.
|
1379
|
+
* @param [in] ptr Pointer to the virtual device memory.
|
1380
|
+
* @param [in] pitch The pitch size by number of elements, including padding.
|
1381
|
+
* @param [in] val The value to be set.
|
1382
|
+
* @param [in] x The width of memory region by number of elements.
|
1383
|
+
* @param [in] y The height of memory region by number of elements.
|
1384
|
+
* @return An event list representing the memset operations.
|
1385
|
+
*/
|
1386
|
+
template <typename valueT>
|
1387
|
+
static inline std::vector<sycl::event>
|
1388
|
+
dpct_memset(sycl::queue &q, void *ptr, size_t pitch, valueT val, size_t x,
|
1389
|
+
size_t y)
|
1390
|
+
{
|
1391
|
+
return dpct_memset(q, pitched_data(ptr, pitch, x, 1), val,
|
1392
|
+
sycl::range<3>(x, y, 1));
|
1393
|
+
}
|
1394
|
+
|
1395
|
+
/// Resolves the direction of a copy. An explicit direction is returned
/// unchanged; `automatic` is resolved from the access attributes of both
/// pointers as reported by get_pointer_attribute.
/// \param q        Queue used for the pointer queries.
/// \param to_ptr   Destination pointer.
/// \param from_ptr Source pointer.
/// \param dir      Requested direction.
/// \return One of the four concrete directions.
/// \throws std::runtime_error if \p dir is not a known enumerator.
static memcpy_direction deduce_memcpy_direction(sycl::queue &q, void *to_ptr,
                                                const void *from_ptr,
                                                memcpy_direction dir)
{
    if (dir == memcpy_direction::automatic)
    {
        constexpr unsigned attr_count =
            static_cast<unsigned>(pointer_access_attribute::end);
        // table[to_attribute][from_attribute]
        static const memcpy_direction direction_table[attr_count][attr_count] =
            {// destination: host_only
             {memcpy_direction::host_to_host,
              memcpy_direction::device_to_host,
              memcpy_direction::host_to_host},
             // destination: device_only
             {memcpy_direction::host_to_device,
              memcpy_direction::device_to_device,
              memcpy_direction::device_to_device},
             // destination: host_device
             {memcpy_direction::host_to_host,
              memcpy_direction::device_to_device,
              memcpy_direction::device_to_device}};
        const unsigned to_attr =
            static_cast<unsigned>(get_pointer_attribute(q, to_ptr));
        const unsigned from_attr =
            static_cast<unsigned>(get_pointer_attribute(q, from_ptr));
        return direction_table[to_attr][from_attr];
    }

    switch (dir)
    {
    case memcpy_direction::host_to_host:
    case memcpy_direction::host_to_device:
    case memcpy_direction::device_to_host:
    case memcpy_direction::device_to_device:
        return dir;
    default:
        break;
    }
    throw std::runtime_error("dpct_memcpy: invalid direction value");
}
|
1428
|
+
|
1429
|
+
/// Submits a linear copy of \p size bytes from \p from_ptr to \p to_ptr on
/// \p q via sycl::queue::memcpy.
/// \param q          Queue that performs the copy.
/// \param to_ptr     Destination.
/// \param from_ptr   Source.
/// \param size       Number of bytes; 0 submits nothing.
/// \param direction  Accepted for interface compatibility; not used here.
/// \param dep_events Events the copy must wait for.
/// \return The copy event, or a default-constructed event when \p size is 0.
static sycl::event
dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size,
            memcpy_direction direction,
            const std::vector<sycl::event> &dep_events = {})
{
    // Mark the parameter as intentionally unused before any return, so the
    // statement is actually reachable.
    GGML_UNUSED(direction);
    if (!size)
        return sycl::event{};
    return q.memcpy(to_ptr, from_ptr, size, dep_events);
}
|
1439
|
+
|
1440
|
+
// Get actual copy range and make sure it will not exceed range.
|
1441
|
+
static inline size_t get_copy_range(sycl::range<3> size, size_t slice,
|
1442
|
+
size_t pitch)
|
1443
|
+
{
|
1444
|
+
return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0);
|
1445
|
+
}
|
1446
|
+
|
1447
|
+
/// Linear offset of the 3D index \p id in a matrix whose consecutive
/// dimension-1 entries are \p pitch apart and whose consecutive dimension-2
/// entries are \p slice apart.
static inline size_t get_offset(sycl::id<3> id, size_t slice,
                                size_t pitch)
{
    const size_t slice_part = slice * id.get(2);
    const size_t row_part = pitch * id.get(1);
    return slice_part + row_part + id.get(0);
}
|
1452
|
+
|
1453
|
+
/// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr
|
1454
|
+
/// and \p from_range to another specified by \p to_ptr and \p to_range.
|
1455
|
+
static inline std::vector<sycl::event>
|
1456
|
+
dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
|
1457
|
+
sycl::range<3> to_range, sycl::range<3> from_range,
|
1458
|
+
sycl::id<3> to_id, sycl::id<3> from_id,
|
1459
|
+
sycl::range<3> size, memcpy_direction direction,
|
1460
|
+
const std::vector<sycl::event> &dep_events = {})
|
1461
|
+
{
|
1462
|
+
// RAII for host pointer
|
1463
|
+
class host_buffer
|
1464
|
+
{
|
1465
|
+
void *_buf;
|
1466
|
+
size_t _size;
|
1467
|
+
sycl::queue &_q;
|
1468
|
+
const std::vector<sycl::event> &_deps; // free operation depends
|
1469
|
+
|
1470
|
+
public:
|
1471
|
+
host_buffer(size_t size, sycl::queue &q,
|
1472
|
+
const std::vector<sycl::event> &deps)
|
1473
|
+
: _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {}
|
1474
|
+
void *get_ptr() const { return _buf; }
|
1475
|
+
size_t get_size() const { return _size; }
|
1476
|
+
~host_buffer()
|
1477
|
+
{
|
1478
|
+
if (_buf)
|
1479
|
+
{
|
1480
|
+
_q.submit([&](sycl::handler &cgh)
|
1481
|
+
{
|
1482
|
+
cgh.depends_on(_deps);
|
1483
|
+
cgh.host_task([buf = _buf] { std::free(buf); }); });
|
1484
|
+
}
|
1485
|
+
}
|
1486
|
+
};
|
1487
|
+
std::vector<sycl::event> event_list;
|
1488
|
+
|
1489
|
+
size_t to_slice = to_range.get(1) * to_range.get(0),
|
1490
|
+
from_slice = from_range.get(1) * from_range.get(0);
|
1491
|
+
unsigned char *to_surface =
|
1492
|
+
(unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0));
|
1493
|
+
const unsigned char *from_surface =
|
1494
|
+
(const unsigned char *)from_ptr +
|
1495
|
+
get_offset(from_id, from_slice, from_range.get(0));
|
1496
|
+
|
1497
|
+
if (to_slice == from_slice && to_slice == size.get(1) * size.get(0))
|
1498
|
+
{
|
1499
|
+
return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2),
|
1500
|
+
direction, dep_events)};
|
1501
|
+
}
|
1502
|
+
direction = deduce_memcpy_direction(q, to_ptr, from_ptr, direction);
|
1503
|
+
size_t size_slice = size.get(1) * size.get(0);
|
1504
|
+
switch (direction)
|
1505
|
+
{
|
1506
|
+
case host_to_host:
|
1507
|
+
for (size_t z = 0; z < size.get(2); ++z)
|
1508
|
+
{
|
1509
|
+
unsigned char *to_ptr = to_surface;
|
1510
|
+
const unsigned char *from_ptr = from_surface;
|
1511
|
+
if (to_range.get(0) == from_range.get(0) &&
|
1512
|
+
to_range.get(0) == size.get(0))
|
1513
|
+
{
|
1514
|
+
event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice,
|
1515
|
+
direction, dep_events));
|
1516
|
+
}
|
1517
|
+
else
|
1518
|
+
{
|
1519
|
+
for (size_t y = 0; y < size.get(1); ++y)
|
1520
|
+
{
|
1521
|
+
event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0),
|
1522
|
+
direction, dep_events));
|
1523
|
+
to_ptr += to_range.get(0);
|
1524
|
+
from_ptr += from_range.get(0);
|
1525
|
+
}
|
1526
|
+
}
|
1527
|
+
to_surface += to_slice;
|
1528
|
+
from_surface += from_slice;
|
1529
|
+
}
|
1530
|
+
break;
|
1531
|
+
case host_to_device:
|
1532
|
+
{
|
1533
|
+
host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q,
|
1534
|
+
event_list);
|
1535
|
+
std::vector<sycl::event> host_events;
|
1536
|
+
if (to_slice == size_slice)
|
1537
|
+
{
|
1538
|
+
// Copy host data to a temp host buffer with the shape of target.
|
1539
|
+
host_events =
|
1540
|
+
dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range,
|
1541
|
+
sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size,
|
1542
|
+
host_to_host, dep_events);
|
1543
|
+
}
|
1544
|
+
else
|
1545
|
+
{
|
1546
|
+
// Copy host data to a temp host buffer with the shape of target.
|
1547
|
+
host_events = dpct_memcpy(
|
1548
|
+
q, buf.get_ptr(), from_surface, to_range, from_range,
|
1549
|
+
sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host,
|
1550
|
+
// If has padding data, not sure whether it is useless. So fill temp
|
1551
|
+
// buffer with it.
|
1552
|
+
std::vector<sycl::event>{
|
1553
|
+
dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(),
|
1554
|
+
device_to_host, dep_events)});
|
1555
|
+
}
|
1556
|
+
// Copy from temp host buffer to device with only one submit.
|
1557
|
+
event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(),
|
1558
|
+
buf.get_size(), host_to_device,
|
1559
|
+
host_events));
|
1560
|
+
break;
|
1561
|
+
}
|
1562
|
+
case device_to_host:
|
1563
|
+
{
|
1564
|
+
host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q,
|
1565
|
+
event_list);
|
1566
|
+
// Copy from host temp buffer to host target with reshaping.
|
1567
|
+
event_list = dpct_memcpy(
|
1568
|
+
q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0),
|
1569
|
+
sycl::id<3>(0, 0, 0), size, host_to_host,
|
1570
|
+
// Copy from device to temp host buffer with only one submit.
|
1571
|
+
std::vector<sycl::event>{dpct_memcpy(q, buf.get_ptr(), from_surface,
|
1572
|
+
buf.get_size(),
|
1573
|
+
device_to_host, dep_events)});
|
1574
|
+
break;
|
1575
|
+
}
|
1576
|
+
case device_to_device:
|
1577
|
+
event_list.push_back(q.submit([&](sycl::handler &cgh){
|
1578
|
+
cgh.depends_on(dep_events);
|
1579
|
+
cgh.parallel_for<class dpct_memcpy_3d_detail>(
|
1580
|
+
size,
|
1581
|
+
[=](sycl::id<3> id) {
|
1582
|
+
to_surface[get_offset(id, to_slice, to_range.get(0))] =
|
1583
|
+
from_surface[get_offset(id, from_slice, from_range.get(0))];
|
1584
|
+
}); }));
|
1585
|
+
break;
|
1586
|
+
default:
|
1587
|
+
throw std::runtime_error("dpct_memcpy: invalid direction value");
|
1588
|
+
}
|
1589
|
+
return event_list;
|
1590
|
+
}
|
1591
|
+
|
1592
|
+
/// memcpy 2D/3D matrix specified by pitched_data.
|
1593
|
+
static inline std::vector<sycl::event>
|
1594
|
+
dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id,
|
1595
|
+
pitched_data from, sycl::id<3> from_id, sycl::range<3> size,
|
1596
|
+
memcpy_direction direction = automatic)
|
1597
|
+
{
|
1598
|
+
return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(),
|
1599
|
+
sycl::range<3>(to.get_pitch(), to.get_y(), 1),
|
1600
|
+
sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id,
|
1601
|
+
size, direction);
|
1602
|
+
}
|
1603
|
+
|
1604
|
+
/// memcpy 2D matrix with pitch.
|
1605
|
+
static inline std::vector<sycl::event>
|
1606
|
+
dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
|
1607
|
+
size_t to_pitch, size_t from_pitch, size_t x, size_t y,
|
1608
|
+
memcpy_direction direction = automatic)
|
1609
|
+
{
|
1610
|
+
return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1),
|
1611
|
+
sycl::range<3>(from_pitch, y, 1),
|
1612
|
+
sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0),
|
1613
|
+
sycl::range<3>(x, y, 1), direction);
|
1614
|
+
}
|
1615
|
+
|
1616
|
+
namespace deprecated
|
1617
|
+
{
|
1618
|
+
|
1619
|
+
template <typename T, sycl::usm::alloc AllocKind>
|
1620
|
+
class usm_allocator
|
1621
|
+
{
|
1622
|
+
private:
|
1623
|
+
using Alloc = sycl::usm_allocator<T, AllocKind>;
|
1624
|
+
Alloc _impl;
|
1625
|
+
|
1626
|
+
public:
|
1627
|
+
using value_type = typename std::allocator_traits<Alloc>::value_type;
|
1628
|
+
using pointer = typename std::allocator_traits<Alloc>::pointer;
|
1629
|
+
using const_pointer = typename std::allocator_traits<Alloc>::const_pointer;
|
1630
|
+
using void_pointer = typename std::allocator_traits<Alloc>::void_pointer;
|
1631
|
+
using const_void_pointer =
|
1632
|
+
typename std::allocator_traits<Alloc>::const_void_pointer;
|
1633
|
+
using reference = typename std::allocator_traits<Alloc>::value_type &;
|
1634
|
+
using const_reference =
|
1635
|
+
const typename std::allocator_traits<Alloc>::value_type &;
|
1636
|
+
using difference_type =
|
1637
|
+
typename std::allocator_traits<Alloc>::difference_type;
|
1638
|
+
using size_type = typename std::allocator_traits<Alloc>::size_type;
|
1639
|
+
using propagate_on_container_copy_assignment = typename std::allocator_traits<
|
1640
|
+
Alloc>::propagate_on_container_copy_assignment;
|
1641
|
+
using propagate_on_container_move_assignment = typename std::allocator_traits<
|
1642
|
+
Alloc>::propagate_on_container_move_assignment;
|
1643
|
+
using propagate_on_container_swap =
|
1644
|
+
typename std::allocator_traits<Alloc>::propagate_on_container_swap;
|
1645
|
+
using is_always_equal =
|
1646
|
+
typename std::allocator_traits<Alloc>::is_always_equal;
|
1647
|
+
|
1648
|
+
template <typename U>
|
1649
|
+
struct rebind
|
1650
|
+
{
|
1651
|
+
typedef usm_allocator<U, AllocKind> other;
|
1652
|
+
};
|
1653
|
+
|
1654
|
+
usm_allocator() : _impl(dpct::get_default_queue()) {}
|
1655
|
+
~usm_allocator() {}
|
1656
|
+
usm_allocator(const usm_allocator &other) : _impl(other._impl) {}
|
1657
|
+
usm_allocator(usm_allocator &&other) : _impl(std::move(other._impl)) {}
|
1658
|
+
pointer address(reference r) { return &r; }
|
1659
|
+
const_pointer address(const_reference r) { return &r; }
|
1660
|
+
pointer allocate(size_type cnt, const_void_pointer hint = nullptr)
|
1661
|
+
{
|
1662
|
+
return std::allocator_traits<Alloc>::allocate(_impl, cnt, hint);
|
1663
|
+
}
|
1664
|
+
void deallocate(pointer p, size_type cnt)
|
1665
|
+
{
|
1666
|
+
std::allocator_traits<Alloc>::deallocate(_impl, p, cnt);
|
1667
|
+
}
|
1668
|
+
size_type max_size() const
|
1669
|
+
{
|
1670
|
+
return std::allocator_traits<Alloc>::max_size(_impl);
|
1671
|
+
}
|
1672
|
+
bool operator==(const usm_allocator &other) const { return _impl == other._impl; }
|
1673
|
+
bool operator!=(const usm_allocator &other) const { return _impl != other._impl; }
|
1674
|
+
};
|
1675
|
+
|
1676
|
+
} // namespace deprecated
|
1677
|
+
|
1678
|
+
inline void dpct_free(void *ptr,
|
1679
|
+
const sycl::queue &q)
|
1680
|
+
{
|
1681
|
+
if (ptr)
|
1682
|
+
{
|
1683
|
+
sycl::free(ptr, q.get_context());
|
1684
|
+
}
|
1685
|
+
}
|
1686
|
+
|
1687
|
+
/// Reinterprets an untyped const pointer as `T *`, dropping the constness of
/// the `const void *` (T itself may still be const-qualified).
/// \tparam T Pointee type of the result.
/// \param x  Address to reinterpret.
/// \return \p x as `T *`.
template <typename T>
inline auto get_memory(const void *x)
{
    void *untyped = const_cast<void *>(x);
    return reinterpret_cast<T *>(untyped);
}
|
1693
|
+
|
1694
|
+
template <typename T>
|
1695
|
+
inline typename DataType<T>::T2 get_value(const T *s, sycl::queue &q)
|
1696
|
+
{
|
1697
|
+
using Ty = typename DataType<T>::T2;
|
1698
|
+
Ty s_h;
|
1699
|
+
if (get_pointer_attribute(q, s) == pointer_access_attribute::device_only)
|
1700
|
+
detail::dpct_memcpy(q, (void *)&s_h, (const void *)s, sizeof(T), device_to_host)
|
1701
|
+
.wait();
|
1702
|
+
else
|
1703
|
+
s_h = *reinterpret_cast<const Ty *>(s);
|
1704
|
+
return s_h;
|
1705
|
+
}
|
1706
|
+
|
1707
|
+
} // namespace detail
|
1708
|
+
|
1709
|
+
template <typename T>
|
1710
|
+
inline auto get_value(const T *s, sycl::queue &q)
|
1711
|
+
{
|
1712
|
+
return detail::get_value(s, q);
|
1713
|
+
}
|
1714
|
+
|
1715
|
+
namespace detail
|
1716
|
+
{
|
1717
|
+
template <class Ta, class Tb, class Tc, class Ts>
|
1718
|
+
inline void gemm_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
|
1719
|
+
int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb,
|
1720
|
+
const void * beta, void * c, int ldc) {
|
1721
|
+
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
1722
|
+
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
1723
|
+
auto data_a = get_memory<const Ta>(a);
|
1724
|
+
auto data_b = get_memory<const Tb>(b);
|
1725
|
+
auto data_c = get_memory<Tc>(c);
|
1726
|
+
oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a,
|
1727
|
+
lda, data_b, ldb, beta_value, data_c, ldc);
|
1728
|
+
}
|
1729
|
+
|
1730
|
+
/// Generic case: applies \p binary_op to each pair of corresponding
/// elements of the two vectors.
/// \tparam VecT            Vector type providing size() and operator[].
/// \tparam BinaryOperation Callable invoked with two elements.
template <typename VecT, class BinaryOperation, class = void>
class vectorized_binary
{
public:
    inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op)
    {
        VecT result;
        for (size_t lane = 0; lane < result.size(); ++lane)
        {
            result[lane] = binary_op(a[lane], b[lane]);
        }
        return result;
    }
};
|
1744
|
+
|
1745
|
+
template <typename VecT, class BinaryOperation>
|
1746
|
+
class vectorized_binary<
|
1747
|
+
VecT, BinaryOperation,
|
1748
|
+
std::void_t<std::invoke_result_t<BinaryOperation, VecT, VecT>>>
|
1749
|
+
{
|
1750
|
+
public:
|
1751
|
+
inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op)
|
1752
|
+
{
|
1753
|
+
return binary_op(a, b).template as<VecT>();
|
1754
|
+
}
|
1755
|
+
};
|
1756
|
+
|
1757
|
+
template <class Ta, class Tb, class Tc, class Ts>
|
1758
|
+
inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
|
1759
|
+
int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
|
1760
|
+
int ldb, const void * beta, void ** c, int ldc, int batch_size,
|
1761
|
+
matrix_info_t<float> * matrix_info) {
|
1762
|
+
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
1763
|
+
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
1764
|
+
|
1765
|
+
matrix_info->transpose_info[0] = a_trans;
|
1766
|
+
matrix_info->transpose_info[1] = b_trans;
|
1767
|
+
matrix_info->value_info[0] = alpha_value;
|
1768
|
+
matrix_info->value_info[1] = beta_value;
|
1769
|
+
matrix_info->size_info[0] = m;
|
1770
|
+
matrix_info->size_info[1] = n;
|
1771
|
+
matrix_info->size_info[2] = k;
|
1772
|
+
matrix_info->ld_info[0] = lda;
|
1773
|
+
matrix_info->ld_info[1] = ldb;
|
1774
|
+
matrix_info->ld_info[2] = ldc;
|
1775
|
+
matrix_info->groupsize_info = batch_size;
|
1776
|
+
|
1777
|
+
sycl::event e = oneapi::math::blas::column_major::gemm_batch(
|
1778
|
+
get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1,
|
1779
|
+
matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2,
|
1780
|
+
reinterpret_cast<Ts *>(matrix_info->value_info), reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
|
1781
|
+
reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
1782
|
+
reinterpret_cast<Ts *>(matrix_info->value_info + 1), reinterpret_cast<Tc **>(c),
|
1783
|
+
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
1784
|
+
}
|
1785
|
+
|
1786
|
+
template <class Ta, class Tb, class Tc, class Ts>
|
1787
|
+
inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
|
1788
|
+
int m, int n, int k, const void * alpha, const void * a, int lda,
|
1789
|
+
long long int stride_a, const void * b, int ldb, long long int stride_b,
|
1790
|
+
const void * beta, void * c, int ldc, long long int stride_c, int batch_size) {
|
1791
|
+
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
1792
|
+
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
1793
|
+
auto data_a = get_memory<const Ta>(a);
|
1794
|
+
auto data_b = get_memory<const Tb>(b);
|
1795
|
+
auto data_c = get_memory<Tc>(c);
|
1796
|
+
oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value,
|
1797
|
+
data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
|
1798
|
+
data_c, ldc, stride_c, batch_size);
|
1799
|
+
}
|
1800
|
+
|
1801
|
+
} // namespace detail
|
1802
|
+
|
1803
|
+
template <typename VecT, class BinaryOperation>
|
1804
|
+
inline unsigned vectorized_binary(unsigned a, unsigned b,
|
1805
|
+
const BinaryOperation binary_op)
|
1806
|
+
{
|
1807
|
+
sycl::vec<unsigned, 1> v0{a}, v1{b};
|
1808
|
+
auto v2 = v0.as<VecT>();
|
1809
|
+
auto v3 = v1.as<VecT>();
|
1810
|
+
auto v4 =
|
1811
|
+
detail::vectorized_binary<VecT, BinaryOperation>()(v2, v3, binary_op);
|
1812
|
+
v0 = v4.template as<sycl::vec<unsigned, 1>>();
|
1813
|
+
return v0;
|
1814
|
+
}
|
1815
|
+
|
1816
|
+
static void async_dpct_memcpy(void *to_ptr, const void *from_ptr, size_t size,
|
1817
|
+
memcpy_direction direction = automatic,
|
1818
|
+
sycl::queue &q = dpct::get_default_queue())
|
1819
|
+
{
|
1820
|
+
detail::dpct_memcpy(q, to_ptr, from_ptr, size, direction);
|
1821
|
+
}
|
1822
|
+
|
1823
|
+
static inline unsigned int select_device(unsigned int id)
|
1824
|
+
{
|
1825
|
+
dev_mgr::instance().select_device(id);
|
1826
|
+
return id;
|
1827
|
+
}
|
1828
|
+
|
1829
|
+
template <typename T>
|
1830
|
+
T permute_sub_group_by_xor(sycl::sub_group g, T x, unsigned int mask,
|
1831
|
+
unsigned int logical_sub_group_size = 32)
|
1832
|
+
{
|
1833
|
+
unsigned int id = g.get_local_linear_id();
|
1834
|
+
unsigned int start_index =
|
1835
|
+
id / logical_sub_group_size * logical_sub_group_size;
|
1836
|
+
unsigned int target_offset = (id % logical_sub_group_size) ^ mask;
|
1837
|
+
return sycl::select_from_group(g, x,
|
1838
|
+
target_offset < logical_sub_group_size
|
1839
|
+
? start_index + target_offset
|
1840
|
+
: id);
|
1841
|
+
}
|
1842
|
+
|
1843
|
+
template <typename T1, typename T2, typename T3>
|
1844
|
+
inline auto dp4a(T1 a, T2 b, T3 c)
|
1845
|
+
{
|
1846
|
+
return syclcompat::dp4a(a, b, c);
|
1847
|
+
}
|
1848
|
+
|
1849
|
+
struct sub_sat
|
1850
|
+
{
|
1851
|
+
template <typename T>
|
1852
|
+
auto operator()(const T x, const T y) const
|
1853
|
+
{
|
1854
|
+
return sycl::sub_sat(x, y);
|
1855
|
+
}
|
1856
|
+
};
|
1857
|
+
|
1858
|
+
template <typename S, typename T>
|
1859
|
+
inline T vectorized_min(T a, T b)
|
1860
|
+
{
|
1861
|
+
sycl::vec<T, 1> v0{a}, v1{b};
|
1862
|
+
auto v2 = v0.template as<S>();
|
1863
|
+
auto v3 = v1.template as<S>();
|
1864
|
+
auto v4 = sycl::min(v2, v3);
|
1865
|
+
v0 = v4.template as<sycl::vec<T, 1>>();
|
1866
|
+
return v0;
|
1867
|
+
}
|
1868
|
+
|
1869
|
+
/// Integer exponent: uses sycl::pown.
inline float pow(const float a, const int b)
{
    return sycl::pown(a, b);
}
/// Integer exponent: uses sycl::pown.
inline double pow(const double a, const int b)
{
    return sycl::pown(a, b);
}
/// Same-type floating point operands: uses sycl::pow.
inline float pow(const float a, const float b)
{
    return sycl::pow(a, b);
}
/// Same-type floating point operands: uses sycl::pow.
inline double pow(const double a, const double b)
{
    return sycl::pow(a, b);
}
/// Floating point base with any exponent type: the exponent is converted to
/// the type of the base.
template <typename T, typename U>
inline typename std::enable_if_t<std::is_floating_point_v<T>, T>
pow(const T a, const U b)
{
    const T exponent = static_cast<T>(b);
    return sycl::pow(a, exponent);
}
/// Non floating point base: both operands are converted to double.
template <typename T, typename U>
inline typename std::enable_if_t<!std::is_floating_point_v<T>, double>
pow(const T a, const U b)
{
    const double base = static_cast<double>(a);
    const double exponent = static_cast<double>(b);
    return sycl::pow(base, exponent);
}
|
1885
|
+
|
1886
|
+
inline double min(const double a, const float b)
|
1887
|
+
{
|
1888
|
+
return sycl::fmin(a, static_cast<double>(b));
|
1889
|
+
}
|
1890
|
+
inline double min(const float a, const double b)
|
1891
|
+
{
|
1892
|
+
return sycl::fmin(static_cast<double>(a), b);
|
1893
|
+
}
|
1894
|
+
inline float min(const float a, const float b) { return sycl::fmin(a, b); }
|
1895
|
+
inline double min(const double a, const double b) { return sycl::fmin(a, b); }
|
1896
|
+
inline std::uint32_t min(const std::uint32_t a, const std::int32_t b)
|
1897
|
+
{
|
1898
|
+
return sycl::min(a, static_cast<std::uint32_t>(b));
|
1899
|
+
}
|
1900
|
+
inline std::uint32_t min(const std::int32_t a, const std::uint32_t b)
|
1901
|
+
{
|
1902
|
+
return sycl::min(static_cast<std::uint32_t>(a), b);
|
1903
|
+
}
|
1904
|
+
inline std::int32_t min(const std::int32_t a, const std::int32_t b)
|
1905
|
+
{
|
1906
|
+
return sycl::min(a, b);
|
1907
|
+
}
|
1908
|
+
inline std::uint32_t min(const std::uint32_t a, const std::uint32_t b)
|
1909
|
+
{
|
1910
|
+
return sycl::min(a, b);
|
1911
|
+
}
|
1912
|
+
inline std::uint64_t min(const std::uint64_t a, const std::int64_t b)
|
1913
|
+
{
|
1914
|
+
return sycl::min(a, static_cast<std::uint64_t>(b));
|
1915
|
+
}
|
1916
|
+
inline std::uint64_t min(const std::int64_t a, const std::uint64_t b)
|
1917
|
+
{
|
1918
|
+
return sycl::min(static_cast<std::uint64_t>(a), b);
|
1919
|
+
}
|
1920
|
+
inline std::int64_t min(const std::int64_t a, const std::int64_t b)
|
1921
|
+
{
|
1922
|
+
return sycl::min(a, b);
|
1923
|
+
}
|
1924
|
+
inline std::uint64_t min(const std::uint64_t a, const std::uint64_t b)
|
1925
|
+
{
|
1926
|
+
return sycl::min(a, b);
|
1927
|
+
}
|
1928
|
+
inline std::uint64_t min(const std::uint64_t a, const std::int32_t b)
|
1929
|
+
{
|
1930
|
+
return sycl::min(a, static_cast<std::uint64_t>(b));
|
1931
|
+
}
|
1932
|
+
inline std::uint64_t min(const std::int32_t a, const std::uint64_t b)
|
1933
|
+
{
|
1934
|
+
return sycl::min(static_cast<std::uint64_t>(a), b);
|
1935
|
+
}
|
1936
|
+
inline std::uint64_t min(const std::uint64_t a, const std::uint32_t b)
|
1937
|
+
{
|
1938
|
+
return sycl::min(a, static_cast<std::uint64_t>(b));
|
1939
|
+
}
|
1940
|
+
inline std::uint64_t min(const std::uint32_t a, const std::uint64_t b)
|
1941
|
+
{
|
1942
|
+
return sycl::min(static_cast<std::uint64_t>(a), b);
|
1943
|
+
}
|
1944
|
+
// max function overloads.
|
1945
|
+
// For floating-point types, `float` or `double` arguments are acceptable.
|
1946
|
+
// For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or
|
1947
|
+
// `std::int64_t` type arguments are acceptable.
|
1948
|
+
inline double max(const double a, const float b)
|
1949
|
+
{
|
1950
|
+
return sycl::fmax(a, static_cast<double>(b));
|
1951
|
+
}
|
1952
|
+
inline double max(const float a, const double b)
|
1953
|
+
{
|
1954
|
+
return sycl::fmax(static_cast<double>(a), b);
|
1955
|
+
}
|
1956
|
+
inline float max(const float a, const float b) { return sycl::fmax(a, b); }
|
1957
|
+
inline double max(const double a, const double b) { return sycl::fmax(a, b); }
|
1958
|
+
inline std::uint32_t max(const std::uint32_t a, const std::int32_t b)
|
1959
|
+
{
|
1960
|
+
return sycl::max(a, static_cast<std::uint32_t>(b));
|
1961
|
+
}
|
1962
|
+
inline std::uint32_t max(const std::int32_t a, const std::uint32_t b)
|
1963
|
+
{
|
1964
|
+
return sycl::max(static_cast<std::uint32_t>(a), b);
|
1965
|
+
}
|
1966
|
+
inline std::int32_t max(const std::int32_t a, const std::int32_t b)
|
1967
|
+
{
|
1968
|
+
return sycl::max(a, b);
|
1969
|
+
}
|
1970
|
+
inline std::uint32_t max(const std::uint32_t a, const std::uint32_t b)
|
1971
|
+
{
|
1972
|
+
return sycl::max(a, b);
|
1973
|
+
}
|
1974
|
+
inline std::uint64_t max(const std::uint64_t a, const std::int64_t b)
|
1975
|
+
{
|
1976
|
+
return sycl::max(a, static_cast<std::uint64_t>(b));
|
1977
|
+
}
|
1978
|
+
inline std::uint64_t max(const std::int64_t a, const std::uint64_t b)
|
1979
|
+
{
|
1980
|
+
return sycl::max(static_cast<std::uint64_t>(a), b);
|
1981
|
+
}
|
1982
|
+
inline std::int64_t max(const std::int64_t a, const std::int64_t b)
|
1983
|
+
{
|
1984
|
+
return sycl::max(a, b);
|
1985
|
+
}
|
1986
|
+
inline std::uint64_t max(const std::uint64_t a, const std::uint64_t b)
|
1987
|
+
{
|
1988
|
+
return sycl::max(a, b);
|
1989
|
+
}
|
1990
|
+
inline std::uint64_t max(const std::uint64_t a, const std::int32_t b)
|
1991
|
+
{
|
1992
|
+
return sycl::max(a, static_cast<std::uint64_t>(b));
|
1993
|
+
}
|
1994
|
+
inline std::uint64_t max(const std::int32_t a, const std::uint64_t b)
|
1995
|
+
{
|
1996
|
+
return sycl::max(static_cast<std::uint64_t>(a), b);
|
1997
|
+
}
|
1998
|
+
inline std::uint64_t max(const std::uint64_t a, const std::uint32_t b)
|
1999
|
+
{
|
2000
|
+
return sycl::max(a, static_cast<std::uint64_t>(b));
|
2001
|
+
}
|
2002
|
+
inline std::uint64_t max(const std::uint32_t a, const std::uint64_t b)
|
2003
|
+
{
|
2004
|
+
return sycl::max(static_cast<std::uint64_t>(a), b);
|
2005
|
+
}
|
2006
|
+
|
2007
|
+
/// Verify that \p dev supports every aspect listed in \p props.
///
/// The aspects are checked in list order; the first one the device does not
/// report via sycl::device::has() aborts the check.
///
/// \param dev Device to query.
/// \param props Aspects that must all be supported.
/// \throws std::runtime_error naming the unsupported feature and the device.
///         fp64 and fp16 are reported as 'double' and 'half'; every other
///         aspect is reported by its enumerator name (or "unknown aspect").
inline void
has_capability_or_fail(const sycl::device &dev,
                       const std::initializer_list<sycl::aspect> &props)
{
// The aspect tables shipped with the SYCL headers are expanded into one
// `case ...: return "<name>";` per aspect inside the lambda below.
#define __SYCL_ASPECT(ASPECT, ID) \
    case sycl::aspect::ASPECT:    \
        return #ASPECT;
#define __SYCL_ASPECT_DEPRECATED(ASPECT, ID, MESSAGE) __SYCL_ASPECT(ASPECT, ID)
#define __SYCL_ASPECT_DEPRECATED_ALIAS(ASPECT, ID, MESSAGE)
    const auto aspect_name = [](sycl::aspect AspectNum) -> std::string
    {
        switch (AspectNum)
        {
#include <sycl/info/aspects.def>
#include <sycl/info/aspects_deprecated.def>
        default:
            return "unknown aspect";
        }
    };
#undef __SYCL_ASPECT_DEPRECATED_ALIAS
#undef __SYCL_ASPECT_DEPRECATED
#undef __SYCL_ASPECT

    for (const auto &required : props)
    {
        if (dev.has(required))
        {
            continue;
        }

        // First unsupported aspect: build the message and bail out.
        const std::string device_name =
            dev.get_info<sycl::info::device::name>();

        std::string feature;
        if (required == sycl::aspect::fp64)
        {
            feature = "double";
        }
        else if (required == sycl::aspect::fp16)
        {
            feature = "half";
        }
        else
        {
            feature = aspect_name(required);
        }

        throw std::runtime_error("'" + feature + "' is not supported in '" +
                                 device_name + "' device");
    }
}
|
2053
|
+
|
2054
|
+
static inline unsigned int get_current_device_id()
|
2055
|
+
{
|
2056
|
+
return dev_mgr::instance().current_device_id();
|
2057
|
+
}
|
2058
|
+
|
2059
|
+
/// Return the device that the dev_mgr singleton reports as current.
static inline device_ext &get_current_device()
{
    auto &&manager = dev_mgr::instance();
    return manager.current_device();
}
|
2063
|
+
|
2064
|
+
/// Look up a device by id through the dev_mgr singleton.
///
/// \param id Identifier passed to dev_mgr::get_device.
/// \return Reference to the device registered under \p id.
static inline device_ext &get_device(unsigned int id)
{
    auto &&manager = dev_mgr::instance();
    return manager.get_device(id);
}
|
2068
|
+
|
2069
|
+
/// Return the in-order queue of the device that dev_mgr reports as current.
static inline sycl::queue &get_in_order_queue()
{
    auto &&active_device = dev_mgr::instance().current_device();
    return active_device.in_order_queue();
}
|
2073
|
+
|
2074
|
+
/// Submit a linear copy of \p size bytes to \p q.
///
/// \param q Queue the copy is submitted to.
/// \param to_ptr Destination pointer.
/// \param from_ptr Source pointer.
/// \param size Number of bytes to copy; zero submits nothing.
/// \param direction Accepted for interface symmetry but not used here.
/// \param dep_events Events the copy must wait for.
/// \return The event of the submitted copy, or a default-constructed event
///         when \p size is zero.
static sycl::event
dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size,
            memcpy_direction direction,
            const std::vector<sycl::event> &dep_events = {})
{
    GGML_UNUSED(direction);

    if (size == 0)
    {
        // Nothing to transfer: hand back an empty event without touching q.
        return sycl::event{};
    }
    return q.memcpy(to_ptr, from_ptr, size, dep_events);
}
|
2084
|
+
|
2085
|
+
// Get actual copy range and make sure it will not exceed range.
|
2086
|
+
static inline size_t get_copy_range(sycl::range<3> size, size_t slice,
|
2087
|
+
size_t pitch)
|
2088
|
+
{
|
2089
|
+
return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0);
|
2090
|
+
}
|
2091
|
+
|
2092
|
+
// Linear offset of the 3D index `id` in a layout with the given row `pitch`
// and `slice` stride: id.get(2) selects the slice, id.get(1) the row and
// id.get(0) the position inside the row.
static inline size_t get_offset(sycl::id<3> id, size_t slice,
                                size_t pitch)
{
    const size_t slice_part = slice * id.get(2);
    const size_t row_part = pitch * id.get(1);
    return slice_part + row_part + id.get(0);
}
|
2097
|
+
|
2098
|
+
/// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr
|
2099
|
+
/// and \p from_range to another specified by \p to_ptr and \p to_range.
|
2100
|
+
static inline std::vector<sycl::event>
|
2101
|
+
dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
|
2102
|
+
sycl::range<3> to_range, sycl::range<3> from_range,
|
2103
|
+
sycl::id<3> to_id, sycl::id<3> from_id,
|
2104
|
+
sycl::range<3> size, memcpy_direction direction,
|
2105
|
+
const std::vector<sycl::event> &dep_events = {})
|
2106
|
+
{
|
2107
|
+
// RAII for host pointer
|
2108
|
+
class host_buffer
|
2109
|
+
{
|
2110
|
+
void *_buf;
|
2111
|
+
size_t _size;
|
2112
|
+
sycl::queue &_q;
|
2113
|
+
const std::vector<sycl::event> &_deps; // free operation depends
|
2114
|
+
|
2115
|
+
public:
|
2116
|
+
host_buffer(size_t size, sycl::queue &q,
|
2117
|
+
const std::vector<sycl::event> &deps)
|
2118
|
+
: _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {}
|
2119
|
+
void *get_ptr() const { return _buf; }
|
2120
|
+
size_t get_size() const { return _size; }
|
2121
|
+
~host_buffer()
|
2122
|
+
{
|
2123
|
+
if (_buf)
|
2124
|
+
{
|
2125
|
+
_q.submit([&](sycl::handler &cgh)
|
2126
|
+
{
|
2127
|
+
cgh.depends_on(_deps);
|
2128
|
+
cgh.host_task([buf = _buf] { std::free(buf); }); });
|
2129
|
+
}
|
2130
|
+
}
|
2131
|
+
};
|
2132
|
+
std::vector<sycl::event> event_list;
|
2133
|
+
|
2134
|
+
size_t to_slice = to_range.get(1) * to_range.get(0),
|
2135
|
+
from_slice = from_range.get(1) * from_range.get(0);
|
2136
|
+
unsigned char *to_surface =
|
2137
|
+
(unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0));
|
2138
|
+
const unsigned char *from_surface =
|
2139
|
+
(const unsigned char *)from_ptr +
|
2140
|
+
get_offset(from_id, from_slice, from_range.get(0));
|
2141
|
+
|
2142
|
+
if (to_slice == from_slice && to_slice == size.get(1) * size.get(0))
|
2143
|
+
{
|
2144
|
+
return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2),
|
2145
|
+
direction, dep_events)};
|
2146
|
+
}
|
2147
|
+
direction = detail::deduce_memcpy_direction(q, to_ptr, from_ptr, direction);
|
2148
|
+
size_t size_slice = size.get(1) * size.get(0);
|
2149
|
+
switch (direction)
|
2150
|
+
{
|
2151
|
+
case host_to_host:
|
2152
|
+
for (size_t z = 0; z < size.get(2); ++z)
|
2153
|
+
{
|
2154
|
+
unsigned char *to_ptr = to_surface;
|
2155
|
+
const unsigned char *from_ptr = from_surface;
|
2156
|
+
if (to_range.get(0) == from_range.get(0) &&
|
2157
|
+
to_range.get(0) == size.get(0))
|
2158
|
+
{
|
2159
|
+
event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice,
|
2160
|
+
direction, dep_events));
|
2161
|
+
}
|
2162
|
+
else
|
2163
|
+
{
|
2164
|
+
for (size_t y = 0; y < size.get(1); ++y)
|
2165
|
+
{
|
2166
|
+
event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0),
|
2167
|
+
direction, dep_events));
|
2168
|
+
to_ptr += to_range.get(0);
|
2169
|
+
from_ptr += from_range.get(0);
|
2170
|
+
}
|
2171
|
+
}
|
2172
|
+
to_surface += to_slice;
|
2173
|
+
from_surface += from_slice;
|
2174
|
+
}
|
2175
|
+
break;
|
2176
|
+
case host_to_device:
|
2177
|
+
{
|
2178
|
+
host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q,
|
2179
|
+
event_list);
|
2180
|
+
std::vector<sycl::event> host_events;
|
2181
|
+
if (to_slice == size_slice)
|
2182
|
+
{
|
2183
|
+
// Copy host data to a temp host buffer with the shape of target.
|
2184
|
+
host_events =
|
2185
|
+
dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range,
|
2186
|
+
sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size,
|
2187
|
+
host_to_host, dep_events);
|
2188
|
+
}
|
2189
|
+
else
|
2190
|
+
{
|
2191
|
+
// Copy host data to a temp host buffer with the shape of target.
|
2192
|
+
host_events = dpct_memcpy(
|
2193
|
+
q, buf.get_ptr(), from_surface, to_range, from_range,
|
2194
|
+
sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host,
|
2195
|
+
// If has padding data, not sure whether it is useless. So fill temp
|
2196
|
+
// buffer with it.
|
2197
|
+
std::vector<sycl::event>{
|
2198
|
+
dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(),
|
2199
|
+
device_to_host, dep_events)});
|
2200
|
+
}
|
2201
|
+
// Copy from temp host buffer to device with only one submit.
|
2202
|
+
event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(),
|
2203
|
+
buf.get_size(), host_to_device,
|
2204
|
+
host_events));
|
2205
|
+
break;
|
2206
|
+
}
|
2207
|
+
case device_to_host:
|
2208
|
+
{
|
2209
|
+
host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q,
|
2210
|
+
event_list);
|
2211
|
+
// Copy from host temp buffer to host target with reshaping.
|
2212
|
+
event_list = dpct_memcpy(
|
2213
|
+
q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0),
|
2214
|
+
sycl::id<3>(0, 0, 0), size, host_to_host,
|
2215
|
+
// Copy from device to temp host buffer with only one submit.
|
2216
|
+
std::vector<sycl::event>{dpct_memcpy(q, buf.get_ptr(), from_surface,
|
2217
|
+
buf.get_size(),
|
2218
|
+
device_to_host, dep_events)});
|
2219
|
+
break;
|
2220
|
+
}
|
2221
|
+
case device_to_device:
|
2222
|
+
event_list.push_back(q.submit([&](sycl::handler &cgh)
|
2223
|
+
{
|
2224
|
+
cgh.depends_on(dep_events);
|
2225
|
+
cgh.parallel_for<class dpct_memcpy_3d_detail>(
|
2226
|
+
size,
|
2227
|
+
[=](sycl::id<3> id) {
|
2228
|
+
to_surface[get_offset(id, to_slice, to_range.get(0))] =
|
2229
|
+
from_surface[get_offset(id, from_slice, from_range.get(0))];
|
2230
|
+
}); }));
|
2231
|
+
break;
|
2232
|
+
default:
|
2233
|
+
throw std::runtime_error("dpct_memcpy: invalid direction value");
|
2234
|
+
}
|
2235
|
+
return event_list;
|
2236
|
+
}
|
2237
|
+
|
2238
|
+
/// memcpy 2D/3D matrix specified by pitched_data.
|
2239
|
+
static inline std::vector<sycl::event>
|
2240
|
+
dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id,
|
2241
|
+
pitched_data from, sycl::id<3> from_id, sycl::range<3> size,
|
2242
|
+
memcpy_direction direction = automatic)
|
2243
|
+
{
|
2244
|
+
return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(),
|
2245
|
+
sycl::range<3>(to.get_pitch(), to.get_y(), 1),
|
2246
|
+
sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id,
|
2247
|
+
size, direction);
|
2248
|
+
}
|
2249
|
+
|
2250
|
+
/// memcpy 2D matrix with pitch.
|
2251
|
+
static inline std::vector<sycl::event>
|
2252
|
+
dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
|
2253
|
+
size_t to_pitch, size_t from_pitch, size_t x, size_t y,
|
2254
|
+
memcpy_direction direction = automatic)
|
2255
|
+
{
|
2256
|
+
return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1),
|
2257
|
+
sycl::range<3>(from_pitch, y, 1),
|
2258
|
+
sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0),
|
2259
|
+
sycl::range<3>(x, y, 1), direction);
|
2260
|
+
}
|
2261
|
+
|
2262
|
+
/// Dispatch a general matrix-matrix product to the detail::gemm_impl
/// instantiation that matches the runtime data-type combination.
///
/// \param [in] q Queue the routine is executed on.
/// \param [in] a_trans Operation applied to A.
/// \param [in] b_trans Operation applied to B.
/// \param [in] m,n,k Problem dimensions forwarded to detail::gemm_impl.
/// \param [in] alpha Scaling factor for the product, typed per \p scaling_type.
/// \param [in] a Input matrix A; \p a_type is its data type, \p lda its
///        leading dimension.
/// \param [in] b Input matrix B; \p b_type is its data type, \p ldb its
///        leading dimension.
/// \param [in] beta Scaling factor for C, typed per \p scaling_type.
/// \param [in,out] c Matrix C; \p c_type is its data type, \p ldc its leading
///        dimension.
/// \param [in] scaling_type Data type of \p alpha and \p beta. A real
///        float/double scaling type is promoted to the matching complex type
///        when \p c_type is complex.
/// \throws std::runtime_error if the type combination has no matching case.
inline void gemm(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n,
                 int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b,
                 library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc,
                 library_data_t scaling_type) {
    using dt = library_data_t;

    // Real scaling factors combined with a complex C are handled by the
    // all-complex cases below.
    const bool promote_to_complex_float =
        scaling_type == dt::real_float && c_type == dt::complex_float;
    const bool promote_to_complex_double =
        scaling_type == dt::real_double && c_type == dt::complex_double;
    if (promote_to_complex_float) {
        scaling_type = dt::complex_float;
    } else if (promote_to_complex_double) {
        scaling_type = dt::complex_double;
    }

    const std::uint64_t combination =
        detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);

    switch (combination) {
        case detail::get_type_combination_id(dt::real_float, dt::real_float, dt::real_float, dt::real_float):
            detail::gemm_impl<float, float, float, float>(
                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
            break;

        case detail::get_type_combination_id(dt::real_double, dt::real_double, dt::real_double, dt::real_double):
            detail::gemm_impl<double, double, double, double>(
                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
            break;

        case detail::get_type_combination_id(dt::complex_float, dt::complex_float, dt::complex_float,
                                             dt::complex_float):
            detail::gemm_impl<std::complex<float>, std::complex<float>, std::complex<float>, std::complex<float>>(
                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
            break;

        case detail::get_type_combination_id(dt::complex_double, dt::complex_double, dt::complex_double,
                                             dt::complex_double):
            detail::gemm_impl<std::complex<double>, std::complex<double>, std::complex<double>,
                              std::complex<double>>(
                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
            break;

        case detail::get_type_combination_id(dt::real_half, dt::real_half, dt::real_half, dt::real_half):
            detail::gemm_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
            break;

#ifdef __INTEL_MKL__
        case detail::get_type_combination_id(dt::real_bfloat16, dt::real_bfloat16, dt::real_float, dt::real_float):
            detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
            break;

        case detail::get_type_combination_id(dt::real_half, dt::real_half, dt::real_float, dt::real_float):
            detail::gemm_impl<sycl::half, sycl::half, float, float>(
                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
            break;

        case detail::get_type_combination_id(dt::real_half, dt::real_half, dt::real_half, dt::real_float):
        {
            // Scaling factors arrive as float; the all-half implementation
            // needs them as half, so read and convert them first.
            float alpha_as_float = dpct::get_value(reinterpret_cast<const float *>(alpha), q);
            float beta_as_float  = dpct::get_value(reinterpret_cast<const float *>(beta), q);
            sycl::half alpha_half(alpha_as_float);
            sycl::half beta_half(beta_as_float);
            detail::gemm_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
                q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc);
            break;
        }

        case detail::get_type_combination_id(dt::real_int8, dt::real_int8, dt::real_float, dt::real_float):
            detail::gemm_impl<std::int8_t, std::int8_t, float, float>(
                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
            break;

        case detail::get_type_combination_id(dt::real_bfloat16, dt::real_bfloat16, dt::real_bfloat16,
                                             dt::real_float):
            detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
            break;

        case detail::get_type_combination_id(dt::real_int8, dt::real_int8, dt::real_int32, dt::real_int32):
        {
            // Scaling factors arrive as int32 but the implementation is
            // instantiated with float scaling, so convert them first.
            float alpha_float = dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
            float beta_float  = dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
            detail::gemm_impl<std::int8_t, std::int8_t, std::int32_t, float>(
                q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc);
            break;
        }
#endif  // __INTEL_MKL__

        default:
            throw std::runtime_error("the combination of data type is unsupported");
    }
}  // gemm()
|
2389
|
+
|
2390
|
+
/// Computes a batch of matrix-matrix product with general matrices.
|
2391
|
+
/// \param [in] q The queue where the routine should be executed.
|
2392
|
+
/// \param [in] a_trans Specifies the operation applied to A.
|
2393
|
+
/// \param [in] b_trans Specifies the operation applied to B.
|
2394
|
+
/// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C.
|
2395
|
+
/// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C.
|
2396
|
+
/// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B).
|
2397
|
+
/// \param [in] alpha Scaling factor for the matrix-matrix product.
|
2398
|
+
/// \param [in] a Input matrix A.
|
2399
|
+
/// \param [in] a_type Data type of the matrix A.
|
2400
|
+
/// \param [in] lda Leading dimension of A.
|
2401
|
+
/// \param [in] b Input matrix B.
|
2402
|
+
/// \param [in] b_type Data type of the matrix B.
|
2403
|
+
/// \param [in] ldb Leading dimension of B.
|
2404
|
+
/// \param [in] beta Scaling factor for matrix C.
|
2405
|
+
/// \param [in, out] c Input/Output matrix C.
|
2406
|
+
/// \param [in] c_type Data type of the matrix C.
|
2407
|
+
/// \param [in] ldc Leading dimension of C.
|
2408
|
+
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
|
2409
|
+
/// \param [in] scaling_type Data type of the scaling factors.
|
2410
|
+
inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
|
2411
|
+
int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
|
2412
|
+
const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
|
2413
|
+
library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
|
2414
|
+
matrix_info_t<float> * matrix_info) {
|
2415
|
+
std::uint64_t key =
|
2416
|
+
detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
|
2417
|
+
switch (key)
|
2418
|
+
{
|
2419
|
+
case detail::get_type_combination_id(
|
2420
|
+
library_data_t::real_float, library_data_t::real_float,
|
2421
|
+
library_data_t::real_float, library_data_t::real_float):
|
2422
|
+
{
|
2423
|
+
detail::gemm_batch_impl<float, float, float, float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
|
2424
|
+
beta, c, ldc, batch_size, matrix_info);
|
2425
|
+
break;
|
2426
|
+
}
|
2427
|
+
case detail::get_type_combination_id(
|
2428
|
+
library_data_t::real_double, library_data_t::real_double,
|
2429
|
+
library_data_t::real_double, library_data_t::real_double):
|
2430
|
+
{
|
2431
|
+
detail::gemm_batch_impl<double, double, double, double>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
|
2432
|
+
beta, c, ldc, batch_size, matrix_info);
|
2433
|
+
break;
|
2434
|
+
}
|
2435
|
+
case detail::get_type_combination_id(
|
2436
|
+
library_data_t::real_half, library_data_t::real_half,
|
2437
|
+
library_data_t::real_half, library_data_t::real_half):
|
2438
|
+
{
|
2439
|
+
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
|
2440
|
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
2441
|
+
break;
|
2442
|
+
}
|
2443
|
+
#ifdef __INTEL_MKL__
|
2444
|
+
case detail::get_type_combination_id(
|
2445
|
+
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
2446
|
+
library_data_t::real_bfloat16, library_data_t::real_float):
|
2447
|
+
{
|
2448
|
+
detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
|
2449
|
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
2450
|
+
break;
|
2451
|
+
}
|
2452
|
+
case detail::get_type_combination_id(
|
2453
|
+
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
2454
|
+
library_data_t::real_float, library_data_t::real_float):
|
2455
|
+
{
|
2456
|
+
detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
|
2457
|
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
2458
|
+
break;
|
2459
|
+
}
|
2460
|
+
#endif
|
2461
|
+
case detail::get_type_combination_id(
|
2462
|
+
library_data_t::real_int8, library_data_t::real_int8,
|
2463
|
+
library_data_t::real_int32, library_data_t::real_int32):
|
2464
|
+
{
|
2465
|
+
float alpha_float =
|
2466
|
+
dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
|
2467
|
+
float beta_float =
|
2468
|
+
dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
|
2469
|
+
detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, float>(
|
2470
|
+
q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
|
2471
|
+
matrix_info);
|
2472
|
+
break;
|
2473
|
+
}
|
2474
|
+
case detail::get_type_combination_id(
|
2475
|
+
library_data_t::real_int8, library_data_t::real_int8,
|
2476
|
+
library_data_t::real_float, library_data_t::real_float):
|
2477
|
+
{
|
2478
|
+
detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
|
2479
|
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
2480
|
+
break;
|
2481
|
+
}
|
2482
|
+
case detail::get_type_combination_id(
|
2483
|
+
library_data_t::real_half, library_data_t::real_half,
|
2484
|
+
library_data_t::real_float, library_data_t::real_float):
|
2485
|
+
{
|
2486
|
+
detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
|
2487
|
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
2488
|
+
break;
|
2489
|
+
}
|
2490
|
+
case detail::get_type_combination_id(
|
2491
|
+
library_data_t::real_half, library_data_t::real_half,
|
2492
|
+
library_data_t::real_half, library_data_t::real_float):
|
2493
|
+
{
|
2494
|
+
float alpha_value =
|
2495
|
+
dpct::get_value(reinterpret_cast<const float *>(alpha), q);
|
2496
|
+
float beta_value =
|
2497
|
+
dpct::get_value(reinterpret_cast<const float *>(beta), q);
|
2498
|
+
sycl::half alpha_half(alpha_value);
|
2499
|
+
sycl::half beta_half(beta_value);
|
2500
|
+
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
|
2501
|
+
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
|
2502
|
+
break;
|
2503
|
+
}
|
2504
|
+
default:
|
2505
|
+
throw std::runtime_error("the combination of data type is unsupported");
|
2506
|
+
}
|
2507
|
+
}
|
2508
|
+
|
2509
|
+
/// Computes a batch of matrix-matrix product with general matrices.
|
2510
|
+
/// \param [in] q The queue where the routine should be executed.
|
2511
|
+
/// \param [in] a_trans Specifies the operation applied to A.
|
2512
|
+
/// \param [in] b_trans Specifies the operation applied to B.
|
2513
|
+
/// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C.
|
2514
|
+
/// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C.
|
2515
|
+
/// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B).
|
2516
|
+
/// \param [in] alpha Scaling factor for the matrix-matrix product.
|
2517
|
+
/// \param [in] a Input matrix A.
|
2518
|
+
/// \param [in] a_type Data type of the matrix A.
|
2519
|
+
/// \param [in] lda Leading dimension of A.
|
2520
|
+
/// \param [in] stride_a Stride between the different A matrices.
|
2521
|
+
/// \param [in] b Input matrix B.
|
2522
|
+
/// \param [in] b_type Data type of the matrix B.
|
2523
|
+
/// \param [in] ldb Leading dimension of B.
|
2524
|
+
/// \param [in] stride_b Stride between the different B matrices.
|
2525
|
+
/// \param [in] beta Scaling factor for matrix C.
|
2526
|
+
/// \param [in, out] c Input/Output matrix C.
|
2527
|
+
/// \param [in] c_type Data type of the matrix C.
|
2528
|
+
/// \param [in] ldc Leading dimension of C.
|
2529
|
+
/// \param [in] stride_c Stride between the different C matrices.
|
2530
|
+
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
|
2531
|
+
/// \param [in] scaling_type Data type of the scaling factors.
|
2532
|
+
inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
|
2533
|
+
int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda,
|
2534
|
+
long long int stride_a, const void * b, library_data_t b_type, int ldb,
|
2535
|
+
long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc,
|
2536
|
+
long long int stride_c, int batch_size, library_data_t scaling_type) {
|
2537
|
+
if (scaling_type == library_data_t::real_float &&
|
2538
|
+
c_type == library_data_t::complex_float)
|
2539
|
+
{
|
2540
|
+
scaling_type = library_data_t::complex_float;
|
2541
|
+
}
|
2542
|
+
else if (scaling_type == library_data_t::real_double &&
|
2543
|
+
c_type == library_data_t::complex_double)
|
2544
|
+
{
|
2545
|
+
scaling_type = library_data_t::complex_double;
|
2546
|
+
}
|
2547
|
+
|
2548
|
+
std::uint64_t key =
|
2549
|
+
detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
|
2550
|
+
switch (key)
|
2551
|
+
{
|
2552
|
+
case detail::get_type_combination_id(
|
2553
|
+
library_data_t::real_float, library_data_t::real_float,
|
2554
|
+
library_data_t::real_float, library_data_t::real_float):
|
2555
|
+
{
|
2556
|
+
detail::gemm_batch_impl<float, float, float, float>(
|
2557
|
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
|
2558
|
+
beta, c, ldc, stride_c, batch_size);
|
2559
|
+
break;
|
2560
|
+
}
|
2561
|
+
case detail::get_type_combination_id(
|
2562
|
+
library_data_t::real_double, library_data_t::real_double,
|
2563
|
+
library_data_t::real_double, library_data_t::real_double):
|
2564
|
+
{
|
2565
|
+
detail::gemm_batch_impl<double, double, double, double>(
|
2566
|
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
|
2567
|
+
beta, c, ldc, stride_c, batch_size);
|
2568
|
+
break;
|
2569
|
+
}
|
2570
|
+
case detail::get_type_combination_id(
|
2571
|
+
library_data_t::complex_float, library_data_t::complex_float,
|
2572
|
+
library_data_t::complex_float, library_data_t::complex_float):
|
2573
|
+
{
|
2574
|
+
detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
|
2575
|
+
std::complex<float>, std::complex<float>>(
|
2576
|
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
|
2577
|
+
beta, c, ldc, stride_c, batch_size);
|
2578
|
+
break;
|
2579
|
+
}
|
2580
|
+
case detail::get_type_combination_id(
|
2581
|
+
library_data_t::complex_double, library_data_t::complex_double,
|
2582
|
+
library_data_t::complex_double, library_data_t::complex_double):
|
2583
|
+
{
|
2584
|
+
detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
|
2585
|
+
std::complex<double>, std::complex<double>>(
|
2586
|
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
|
2587
|
+
beta, c, ldc, stride_c, batch_size);
|
2588
|
+
break;
|
2589
|
+
}
|
2590
|
+
case detail::get_type_combination_id(
|
2591
|
+
library_data_t::real_half, library_data_t::real_half,
|
2592
|
+
library_data_t::real_half, library_data_t::real_half):
|
2593
|
+
{
|
2594
|
+
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
|
2595
|
+
sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
|
2596
|
+
a, lda, stride_a, b, ldb, stride_b,
|
2597
|
+
beta, c, ldc, stride_c, batch_size);
|
2598
|
+
break;
|
2599
|
+
}
|
2600
|
+
#ifdef __INTEL_MKL__
|
2601
|
+
case detail::get_type_combination_id(
|
2602
|
+
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
2603
|
+
library_data_t::real_bfloat16, library_data_t::real_float):
|
2604
|
+
{
|
2605
|
+
detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
|
2606
|
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
|
2607
|
+
batch_size);
|
2608
|
+
break;
|
2609
|
+
}
|
2610
|
+
case detail::get_type_combination_id(
|
2611
|
+
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
2612
|
+
library_data_t::real_float, library_data_t::real_float):
|
2613
|
+
{
|
2614
|
+
detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
|
2615
|
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
|
2616
|
+
batch_size);
|
2617
|
+
break;
|
2618
|
+
}
|
2619
|
+
#endif
|
2620
|
+
case detail::get_type_combination_id(
|
2621
|
+
library_data_t::real_int8, library_data_t::real_int8,
|
2622
|
+
library_data_t::real_int32, library_data_t::real_int32):
|
2623
|
+
{
|
2624
|
+
detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
|
2625
|
+
std::int32_t>(q, a_trans, b_trans, m, n, k, alpha,
|
2626
|
+
a, lda, stride_a, b, ldb, stride_b,
|
2627
|
+
beta, c, ldc, stride_c, batch_size);
|
2628
|
+
break;
|
2629
|
+
}
|
2630
|
+
case detail::get_type_combination_id(
|
2631
|
+
library_data_t::real_int8, library_data_t::real_int8,
|
2632
|
+
library_data_t::real_float, library_data_t::real_float):
|
2633
|
+
{
|
2634
|
+
detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
|
2635
|
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
|
2636
|
+
beta, c, ldc, stride_c, batch_size);
|
2637
|
+
break;
|
2638
|
+
}
|
2639
|
+
case detail::get_type_combination_id(
|
2640
|
+
library_data_t::real_half, library_data_t::real_half,
|
2641
|
+
library_data_t::real_float, library_data_t::real_float):
|
2642
|
+
{
|
2643
|
+
detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
|
2644
|
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
|
2645
|
+
beta, c, ldc, stride_c, batch_size);
|
2646
|
+
break;
|
2647
|
+
}
|
2648
|
+
case detail::get_type_combination_id(
|
2649
|
+
library_data_t::real_half, library_data_t::real_half,
|
2650
|
+
library_data_t::real_half, library_data_t::real_float):
|
2651
|
+
{
|
2652
|
+
float alpha_value =
|
2653
|
+
dpct::get_value(reinterpret_cast<const float *>(alpha), q);
|
2654
|
+
float beta_value =
|
2655
|
+
dpct::get_value(reinterpret_cast<const float *>(beta), q);
|
2656
|
+
sycl::half alpha_half(alpha_value);
|
2657
|
+
sycl::half beta_half(beta_value);
|
2658
|
+
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
|
2659
|
+
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, stride_a, b, ldb, stride_b,
|
2660
|
+
&beta_half, c, ldc, stride_c, batch_size);
|
2661
|
+
break;
|
2662
|
+
}
|
2663
|
+
default:
|
2664
|
+
throw std::runtime_error("the combination of data type is unsupported");
|
2665
|
+
}
|
2666
|
+
}
|
2667
|
+
|
2668
|
+
static inline void
|
2669
|
+
async_dpct_memcpy(void *to_ptr, size_t to_pitch, const void *from_ptr,
|
2670
|
+
size_t from_pitch, size_t x, size_t y,
|
2671
|
+
memcpy_direction direction = automatic,
|
2672
|
+
sycl::queue &q = get_default_queue())
|
2673
|
+
{
|
2674
|
+
detail::dpct_memcpy(q, to_ptr, from_ptr, to_pitch, from_pitch, x, y,
|
2675
|
+
direction);
|
2676
|
+
}
|
2677
|
+
|
2678
|
+
// Error-type aliases. Both instantiate detail::generic_error_type with `int`
// as the second template argument and differ only in the tag struct
// (err0_tag vs err1_tag) -- presumably so that err0 and err1 are distinct
// types; confirm against the definition of generic_error_type.
using err0 = detail::generic_error_type<struct err0_tag, int>;
|
2679
|
+
using err1 = detail::generic_error_type<struct err1_tag, int>;
|
2680
|
+
|
2681
|
+
/// Public entry point that forwards \p ptr and \p q to detail::dpct_free.
/// \param ptr Pointer handed to detail::dpct_free.
/// \param q   Queue handed to detail::dpct_free; defaults to
///            get_default_queue().
static inline void dpct_free(void *ptr,
                             sycl::queue &q = get_default_queue())
{
    detail::dpct_free(ptr, q);
}
|
2684
|
+
|
2685
|
+
/// dpct accessor used as device function parameter.
|
2686
|
+
template <class T, memory_region Memory, size_t Dimension> class accessor;
|
2687
|
+
template <class T, memory_region Memory> class accessor<T, Memory, 3> {
|
2688
|
+
public:
|
2689
|
+
using memory_t = detail::memory_traits<Memory, T>;
|
2690
|
+
using element_t = typename memory_t::element_t;
|
2691
|
+
using pointer_t = typename memory_t::pointer_t;
|
2692
|
+
using accessor_t = typename memory_t::template accessor_t<3>;
|
2693
|
+
accessor(pointer_t data, const sycl::range<3> &in_range)
|
2694
|
+
: _data(data), _range(in_range) {}
|
2695
|
+
template <memory_region M = Memory>
|
2696
|
+
accessor(typename std::enable_if<M != local, const accessor_t>::type &acc)
|
2697
|
+
: accessor(acc, acc.get_range()) {}
|
2698
|
+
accessor(const accessor_t &acc, const sycl::range<3> &in_range)
|
2699
|
+
: accessor(acc.get_pointer(), in_range) {}
|
2700
|
+
accessor<T, Memory, 2> operator[](size_t index) const {
|
2701
|
+
sycl::range<2> sub(_range.get(1), _range.get(2));
|
2702
|
+
return accessor<T, Memory, 2>(_data + index * sub.size(), sub);
|
2703
|
+
}
|
2704
|
+
|
2705
|
+
pointer_t get_ptr() const { return _data; }
|
2706
|
+
|
2707
|
+
private:
|
2708
|
+
pointer_t _data;
|
2709
|
+
sycl::range<3> _range;
|
2710
|
+
};
|
2711
|
+
template <class T, memory_region Memory> class accessor<T, Memory, 2> {
|
2712
|
+
public:
|
2713
|
+
using memory_t = detail::memory_traits<Memory, T>;
|
2714
|
+
using element_t = typename memory_t::element_t;
|
2715
|
+
using pointer_t = typename memory_t::pointer_t;
|
2716
|
+
using accessor_t = typename memory_t::template accessor_t<2>;
|
2717
|
+
accessor(pointer_t data, const sycl::range<2> &in_range)
|
2718
|
+
: _data(data), _range(in_range) {}
|
2719
|
+
template <memory_region M = Memory>
|
2720
|
+
accessor(typename std::enable_if<M != local, const accessor_t>::type &acc)
|
2721
|
+
: accessor(acc, acc.get_range()) {}
|
2722
|
+
accessor(const accessor_t &acc, const sycl::range<2> &in_range)
|
2723
|
+
: accessor(acc.get_pointer(), in_range) {}
|
2724
|
+
|
2725
|
+
pointer_t operator[](size_t index) const {
|
2726
|
+
return _data + _range.get(1) * index;
|
2727
|
+
}
|
2728
|
+
|
2729
|
+
pointer_t get_ptr() const { return _data; }
|
2730
|
+
|
2731
|
+
private:
|
2732
|
+
pointer_t _data;
|
2733
|
+
sycl::range<2> _range;
|
2734
|
+
};
|
2735
|
+
|
2736
|
+
namespace detail {
|
2737
|
+
/// Device variable with address space of shared, global or constant.
|
2738
|
+
template <class T, memory_region Memory, size_t Dimension> class device_memory {
|
2739
|
+
public:
|
2740
|
+
using accessor_t =
|
2741
|
+
typename detail::memory_traits<Memory,
|
2742
|
+
T>::template accessor_t<Dimension>;
|
2743
|
+
using value_t = typename detail::memory_traits<Memory, T>::value_t;
|
2744
|
+
using dpct_accessor_t = dpct::accessor<T, Memory, Dimension>;
|
2745
|
+
|
2746
|
+
device_memory() : device_memory(sycl::range<Dimension>(1)) {}
|
2747
|
+
|
2748
|
+
/// Constructor of 1-D array with initializer list
|
2749
|
+
device_memory(const sycl::range<Dimension> &in_range,
|
2750
|
+
std::initializer_list<value_t> &&init_list)
|
2751
|
+
: device_memory(in_range) {
|
2752
|
+
assert(init_list.size() <= in_range.size());
|
2753
|
+
_host_ptr = (value_t *)std::malloc(_size);
|
2754
|
+
std::memset(_host_ptr, 0, _size);
|
2755
|
+
std::memcpy(_host_ptr, init_list.begin(), init_list.size() * sizeof(T));
|
2756
|
+
}
|
2757
|
+
|
2758
|
+
/// Constructor of 2-D array with initializer list
|
2759
|
+
template <size_t D = Dimension>
|
2760
|
+
device_memory(
|
2761
|
+
const typename std::enable_if<D == 2, sycl::range<2>>::type &in_range,
|
2762
|
+
std::initializer_list<std::initializer_list<value_t>> &&init_list)
|
2763
|
+
: device_memory(in_range) {
|
2764
|
+
assert(init_list.size() <= in_range[0]);
|
2765
|
+
_host_ptr = (value_t *)std::malloc(_size);
|
2766
|
+
std::memset(_host_ptr, 0, _size);
|
2767
|
+
auto tmp_data = _host_ptr;
|
2768
|
+
for (auto sub_list : init_list) {
|
2769
|
+
assert(sub_list.size() <= in_range[1]);
|
2770
|
+
std::memcpy(tmp_data, sub_list.begin(),
|
2771
|
+
sub_list.size() * sizeof(T));
|
2772
|
+
tmp_data += in_range[1];
|
2773
|
+
}
|
2774
|
+
}
|
2775
|
+
|
2776
|
+
/// Constructor with range
|
2777
|
+
device_memory(const sycl::range<Dimension> &range_in)
|
2778
|
+
: _size(range_in.size() * sizeof(T)), _range(range_in),
|
2779
|
+
_reference(false), _host_ptr(nullptr), _device_ptr(nullptr) {
|
2780
|
+
static_assert(
|
2781
|
+
(Memory == global) || (Memory == constant) || (Memory == shared),
|
2782
|
+
"device memory region should be global, constant or shared");
|
2783
|
+
// Make sure that singleton class mem_mgr and dev_mgr will destruct
|
2784
|
+
// later than this.
|
2785
|
+
detail::mem_mgr::instance();
|
2786
|
+
dev_mgr::instance();
|
2787
|
+
}
|
2788
|
+
|
2789
|
+
/// Constructor with range
|
2790
|
+
template <class... Args>
|
2791
|
+
device_memory(Args... Arguments)
|
2792
|
+
: device_memory(sycl::range<Dimension>(Arguments...)) {}
|
2793
|
+
|
2794
|
+
~device_memory() {
|
2795
|
+
if (_device_ptr && !_reference)
|
2796
|
+
dpct::dpct_free(_device_ptr);
|
2797
|
+
if (_host_ptr)
|
2798
|
+
std::free(_host_ptr);
|
2799
|
+
}
|
2800
|
+
|
2801
|
+
/// Allocate memory with default queue, and init memory if has initial
|
2802
|
+
/// value.
|
2803
|
+
void init() { init(dpct::get_default_queue()); }
|
2804
|
+
/// Allocate memory with specified queue, and init memory if has initial
|
2805
|
+
/// value.
|
2806
|
+
void init(sycl::queue &q) {
|
2807
|
+
if (_device_ptr)
|
2808
|
+
return;
|
2809
|
+
if (!_size)
|
2810
|
+
return;
|
2811
|
+
allocate_device(q);
|
2812
|
+
if (_host_ptr)
|
2813
|
+
detail::dpct_memcpy(q, _device_ptr, _host_ptr, _size,
|
2814
|
+
host_to_device);
|
2815
|
+
}
|
2816
|
+
|
2817
|
+
/// The variable is assigned to a device pointer.
|
2818
|
+
void assign(value_t *src, size_t size) {
|
2819
|
+
this->~device_memory();
|
2820
|
+
new (this) device_memory(src, size);
|
2821
|
+
}
|
2822
|
+
|
2823
|
+
/// Get memory pointer of the memory object, which is virtual pointer when
|
2824
|
+
/// usm is not used, and device pointer when usm is used.
|
2825
|
+
value_t *get_ptr() { return get_ptr(get_default_queue()); }
|
2826
|
+
/// Get memory pointer of the memory object, which is virtual pointer when
|
2827
|
+
/// usm is not used, and device pointer when usm is used.
|
2828
|
+
value_t *get_ptr(sycl::queue &q) {
|
2829
|
+
init(q);
|
2830
|
+
return _device_ptr;
|
2831
|
+
}
|
2832
|
+
|
2833
|
+
/// Get the device memory object size in bytes.
|
2834
|
+
size_t get_size() { return _size; }
|
2835
|
+
|
2836
|
+
template <size_t D = Dimension>
|
2837
|
+
typename std::enable_if<D == 1, T>::type &operator[](size_t index) {
|
2838
|
+
init();
|
2839
|
+
return _device_ptr[index];
|
2840
|
+
}
|
2841
|
+
|
2842
|
+
/// Get dpct::accessor with dimension info for the device memory object
|
2843
|
+
/// when usm is used and dimension is greater than 1.
|
2844
|
+
template <size_t D = Dimension>
|
2845
|
+
typename std::enable_if<D != 1, dpct_accessor_t>::type
|
2846
|
+
get_access([[maybe_unused]] sycl::handler &cgh) {
|
2847
|
+
return dpct_accessor_t((T *)_device_ptr, _range);
|
2848
|
+
}
|
2849
|
+
|
2850
|
+
private:
|
2851
|
+
device_memory(value_t *memory_ptr, size_t size)
|
2852
|
+
: _size(size), _range(size / sizeof(T)), _reference(true),
|
2853
|
+
_device_ptr(memory_ptr) {}
|
2854
|
+
|
2855
|
+
void allocate_device(sycl::queue &q) {
|
2856
|
+
#ifndef DPCT_USM_LEVEL_NONE
|
2857
|
+
if (Memory == shared) {
|
2858
|
+
_device_ptr = (value_t *)sycl::malloc_shared(_size, q.get_device(),
|
2859
|
+
q.get_context());
|
2860
|
+
return;
|
2861
|
+
}
|
2862
|
+
#ifdef SYCL_EXT_ONEAPI_USM_DEVICE_READ_ONLY
|
2863
|
+
if (Memory == constant) {
|
2864
|
+
_device_ptr = (value_t *)sycl::malloc_device(
|
2865
|
+
_size, q.get_device(), q.get_context(),
|
2866
|
+
sycl::ext::oneapi::property::usm::device_read_only());
|
2867
|
+
return;
|
2868
|
+
}
|
2869
|
+
#endif
|
2870
|
+
#endif
|
2871
|
+
_device_ptr = (value_t *)detail::dpct_malloc(_size, q);
|
2872
|
+
}
|
2873
|
+
|
2874
|
+
size_t _size;
|
2875
|
+
sycl::range<Dimension> _range;
|
2876
|
+
bool _reference;
|
2877
|
+
value_t *_host_ptr;
|
2878
|
+
value_t *_device_ptr;
|
2879
|
+
};
|
2880
|
+
template <class T, memory_region Memory>
|
2881
|
+
class device_memory<T, Memory, 0> : public device_memory<T, Memory, 1> {
|
2882
|
+
public:
|
2883
|
+
using base = device_memory<T, Memory, 1>;
|
2884
|
+
using value_t = typename base::value_t;
|
2885
|
+
using accessor_t =
|
2886
|
+
typename detail::memory_traits<Memory, T>::template accessor_t<0>;
|
2887
|
+
|
2888
|
+
/// Constructor with initial value.
|
2889
|
+
device_memory(const value_t &val) : base(sycl::range<1>(1), {val}) {}
|
2890
|
+
|
2891
|
+
/// Default constructor
|
2892
|
+
device_memory() : base(1) {}
|
2893
|
+
};
|
2894
|
+
} // namespace detail
|
2895
|
+
|
2896
|
+
/// Convenience aliases that bind detail::device_memory to one fixed memory
/// region each: `global`, `constant` and `shared` respectively. The element
/// type and the dimension remain template parameters.
template <class T, size_t Dimension>
|
2897
|
+
using global_memory = detail::device_memory<T, global, Dimension>;
|
2898
|
+
template <class T, size_t Dimension>
|
2899
|
+
using constant_memory = detail::device_memory<T, constant, Dimension>;
|
2900
|
+
template <class T, size_t Dimension>
|
2901
|
+
using shared_memory = detail::device_memory<T, shared, Dimension>;
|
2902
|
+
|
2903
|
+
|
2904
|
+
template <typename T,
|
2905
|
+
sycl::access::address_space addressSpace =
|
2906
|
+
sycl::access::address_space::global_space,
|
2907
|
+
sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
|
2908
|
+
sycl::memory_scope memoryScope = sycl::memory_scope::device>
|
2909
|
+
inline T atomic_fetch_add(T *addr, T operand) {
|
2910
|
+
auto atm =
|
2911
|
+
sycl::atomic_ref<T, memoryOrder, memoryScope, addressSpace>(addr[0]);
|
2912
|
+
return atm.fetch_add(operand);
|
2913
|
+
}
|
2914
|
+
|
2915
|
+
template <sycl::access::address_space addressSpace =
|
2916
|
+
sycl::access::address_space::global_space,
|
2917
|
+
sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
|
2918
|
+
sycl::memory_scope memoryScope = sycl::memory_scope::device,
|
2919
|
+
typename T1, typename T2>
|
2920
|
+
inline T1 atomic_fetch_add(T1 *addr, T2 operand) {
|
2921
|
+
auto atm =
|
2922
|
+
sycl::atomic_ref<T1, memoryOrder, memoryScope, addressSpace>(addr[0]);
|
2923
|
+
return atm.fetch_add(operand);
|
2924
|
+
}
|
2925
|
+
|
2926
|
+
template <typename T, sycl::access::address_space addressSpace =
|
2927
|
+
sycl::access::address_space::global_space>
|
2928
|
+
inline T atomic_fetch_add(T *addr, T operand,
|
2929
|
+
sycl::memory_order memoryOrder) {
|
2930
|
+
switch (memoryOrder) {
|
2931
|
+
case sycl::memory_order::relaxed:
|
2932
|
+
return atomic_fetch_add<T, addressSpace, sycl::memory_order::relaxed,
|
2933
|
+
sycl::memory_scope::device>(addr, operand);
|
2934
|
+
case sycl::memory_order::acq_rel:
|
2935
|
+
return atomic_fetch_add<T, addressSpace, sycl::memory_order::acq_rel,
|
2936
|
+
sycl::memory_scope::device>(addr, operand);
|
2937
|
+
case sycl::memory_order::seq_cst:
|
2938
|
+
return atomic_fetch_add<T, addressSpace, sycl::memory_order::seq_cst,
|
2939
|
+
sycl::memory_scope::device>(addr, operand);
|
2940
|
+
default:
|
2941
|
+
assert(false && "Invalid memory_order for atomics. Valid memory_order for "
|
2942
|
+
"atomics are: sycl::memory_order::relaxed, "
|
2943
|
+
"sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!");
|
2944
|
+
}
|
2945
|
+
}
|
2946
|
+
|
2947
|
+
template <sycl::access::address_space addressSpace =
|
2948
|
+
sycl::access::address_space::global_space,
|
2949
|
+
typename T1, typename T2>
|
2950
|
+
inline T1 atomic_fetch_add(T1 *addr, T2 operand,
|
2951
|
+
sycl::memory_order memoryOrder) {
|
2952
|
+
atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
|
2953
|
+
}
|
2954
|
+
|
2955
|
+
} // COPY from DPCT head files
|
2956
|
+
|
2957
|
+
#endif // GGML_SYCL_DPCT_HELPER_HPP
|