whispercpp 1.3.3 → 1.3.5
This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +79 -25
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/addon.cpp +19 -19
- data/ext/sources/examples/addon.node/index.js +7 -5
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +26 -16
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +122 -111
- data/ext/sources/examples/command/command.cpp +26 -24
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/lsp/lsp.cpp +19 -17
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +34 -24
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +4 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
- data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
- data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
- data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
- data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
- data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
- data/ext/sources/examples/talk-llama/llama-context.h +99 -36
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
- data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
- data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
- data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
- data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
- data/ext/sources/examples/talk-llama/llama-model.h +104 -12
- data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
- data/ext/sources/examples/talk-llama/llama.cpp +794 -12
- data/ext/sources/examples/talk-llama/llama.h +246 -190
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
- data/ext/sources/ggml/CMakeLists.txt +135 -79
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +21 -2
- data/ext/sources/ggml/include/ggml-cpu.h +2 -1
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +406 -23
- data/ext/sources/ggml/src/CMakeLists.txt +99 -13
- data/ext/sources/ggml/src/ggml-alloc.c +368 -161
- data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
- data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
- data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
- data/ext/sources/ggml/src/ggml-common.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
- data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
- data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
- data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
- data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
- data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
- data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
- data/ext/sources/ggml/src/ggml-impl.h +186 -15
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +111 -16
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +901 -129
- data/ext/sources/ggml/src/gguf.cpp +8 -1
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +124 -81
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +7 -5
- data/ext/sources/tests/test-vad.cpp +3 -3
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +126 -2
- data/test/test_params.rb +24 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +8 -1
- data/whispercpp.gemspec +1 -1
- metadata +439 -179
- data/ext/sources/build-xcframework.sh +0 -547
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
|
@@ -30,19 +30,29 @@
|
|
|
30
30
|
#include <regex>
|
|
31
31
|
|
|
32
32
|
#include <sycl/sycl.hpp>
|
|
33
|
+
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
|
|
34
|
+
# include <sycl/ext/oneapi/experimental/async_alloc/async_alloc.hpp>
|
|
35
|
+
#endif
|
|
33
36
|
#include <sycl/half_type.hpp>
|
|
34
37
|
|
|
35
38
|
#include "ggml-sycl.h"
|
|
36
39
|
#include "ggml-impl.h"
|
|
37
40
|
#include "ggml-backend-impl.h"
|
|
38
41
|
|
|
42
|
+
#include "ggml-sycl/add-id.hpp"
|
|
39
43
|
#include "ggml-sycl/backend.hpp"
|
|
40
44
|
#include "ggml-sycl/common.hpp"
|
|
41
45
|
#include "ggml-sycl/element_wise.hpp"
|
|
46
|
+
#include "ggml-sycl/norm.hpp"
|
|
42
47
|
#include "ggml-sycl/presets.hpp"
|
|
43
48
|
#include "ggml-sycl/gemm.hpp"
|
|
49
|
+
#include "ggml-sycl/set_rows.hpp"
|
|
50
|
+
#include "ggml-sycl/set.hpp"
|
|
44
51
|
#include "ggml-sycl/sycl_hw.hpp"
|
|
45
52
|
#include "ggml-sycl/getrows.hpp"
|
|
53
|
+
#include "ggml-sycl/repeat_back.hpp"
|
|
54
|
+
#include "ggml-sycl/quantize.hpp"
|
|
55
|
+
#include "ggml-sycl/ssm_conv.hpp"
|
|
46
56
|
#include "ggml.h"
|
|
47
57
|
|
|
48
58
|
static bool g_sycl_loaded = false;
|
|
@@ -51,6 +61,7 @@ int g_ggml_sycl_disable_optimize = 0;
|
|
|
51
61
|
int g_ggml_sycl_disable_graph = 0;
|
|
52
62
|
int g_ggml_sycl_disable_dnn = 0;
|
|
53
63
|
int g_ggml_sycl_prioritize_dmmv = 0;
|
|
64
|
+
int g_ggml_sycl_use_async_mem_op = 0;
|
|
54
65
|
|
|
55
66
|
static ggml_sycl_device_info ggml_sycl_init() {
|
|
56
67
|
ggml_sycl_device_info info = {};
|
|
@@ -83,7 +94,10 @@ static ggml_sycl_device_info ggml_sycl_init() {
|
|
|
83
94
|
|
|
84
95
|
info.devices[i].cc =
|
|
85
96
|
100 * prop.get_major_version() + 10 * prop.get_minor_version();
|
|
86
|
-
info.devices[i].
|
|
97
|
+
info.devices[i].nsm = prop.get_max_compute_units();
|
|
98
|
+
info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
|
|
99
|
+
info.devices[i].smpbo = prop.get_local_mem_size();
|
|
100
|
+
|
|
87
101
|
info.max_work_group_sizes[i] = prop.get_max_work_group_size();
|
|
88
102
|
}
|
|
89
103
|
|
|
@@ -231,7 +245,20 @@ static void ggml_check_sycl() try {
|
|
|
231
245
|
fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__);
|
|
232
246
|
#endif
|
|
233
247
|
*/
|
|
234
|
-
|
|
248
|
+
// Currently, we only use async malloc / free when graphs are enabled as it is required for the calls to be
|
|
249
|
+
// properly recorded. As this SYCL extension matures it may be beneficial to enable as the default path and in
|
|
250
|
+
// other places.
|
|
251
|
+
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
|
|
252
|
+
g_ggml_sycl_use_async_mem_op = !g_ggml_sycl_disable_graph;
|
|
253
|
+
if (g_ggml_sycl_use_async_mem_op) {
|
|
254
|
+
for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); ++i) {
|
|
255
|
+
if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) {
|
|
256
|
+
g_ggml_sycl_use_async_mem_op = 0;
|
|
257
|
+
break;
|
|
258
|
+
}
|
|
259
|
+
}
|
|
260
|
+
}
|
|
261
|
+
#endif
|
|
235
262
|
if (CHECK_TRY_ERROR(g_all_sycl_device_count =
|
|
236
263
|
dpct::dev_mgr::instance().device_count()) != 0) {
|
|
237
264
|
initialized = true;
|
|
@@ -1372,120 +1399,6 @@ typedef void (*ggml_sycl_op_mul_mat_t)(
|
|
|
1372
1399
|
|
|
1373
1400
|
|
|
1374
1401
|
|
|
1375
|
-
template<int QUANT_BLOCK_TILE>
|
|
1376
|
-
static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
|
|
1377
|
-
const sycl::nd_item<3> &item_ct1) {
|
|
1378
|
-
const int ix = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
1379
|
-
item_ct1.get_local_id(2)) * QUANT_BLOCK_TILE;
|
|
1380
|
-
|
|
1381
|
-
if (ix >= kx_padded) {
|
|
1382
|
-
return;
|
|
1383
|
-
}
|
|
1384
|
-
|
|
1385
|
-
const int iy = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
|
1386
|
-
item_ct1.get_local_id(1);
|
|
1387
|
-
|
|
1388
|
-
const int i_padded = iy*kx_padded + ix;
|
|
1389
|
-
|
|
1390
|
-
block_q8_1 * y = (block_q8_1 *) vy;
|
|
1391
|
-
|
|
1392
|
-
const int ib = i_padded / QK8_1; // block index
|
|
1393
|
-
const int iqs = i_padded % QK8_1; // quant index
|
|
1394
|
-
typedef sycl::vec<float, QUANT_BLOCK_TILE> TC;
|
|
1395
|
-
typedef sycl::vec<int8_t, QUANT_BLOCK_TILE> TQ;
|
|
1396
|
-
TC zeros;
|
|
1397
|
-
TQ qzeros;
|
|
1398
|
-
#pragma unroll
|
|
1399
|
-
for (int i = 0; i < QUANT_BLOCK_TILE; i++)
|
|
1400
|
-
{
|
|
1401
|
-
zeros[i] = 0.f;
|
|
1402
|
-
qzeros[i] = 0;
|
|
1403
|
-
}
|
|
1404
|
-
const TC xi = ix < kx ? *(const TC *)&x[iy * kx + ix] : zeros;
|
|
1405
|
-
float sum = xi[0];
|
|
1406
|
-
float amax = sycl::fabs(xi[0]);
|
|
1407
|
-
#pragma unroll
|
|
1408
|
-
for (int i = 1; i < QUANT_BLOCK_TILE; i++)
|
|
1409
|
-
{
|
|
1410
|
-
sum += xi[i];
|
|
1411
|
-
amax = sycl::fmax(sycl::fabs(xi[i]), amax);
|
|
1412
|
-
}
|
|
1413
|
-
sum = warp_reduce_sum(sum, item_ct1);
|
|
1414
|
-
amax = warp_reduce_max(amax, item_ct1);
|
|
1415
|
-
|
|
1416
|
-
const float d = amax / 127;
|
|
1417
|
-
TQ q = qzeros;
|
|
1418
|
-
if (amax != 0.0f)
|
|
1419
|
-
{
|
|
1420
|
-
#pragma unroll
|
|
1421
|
-
for (int i = 0; i < QUANT_BLOCK_TILE; i++) {
|
|
1422
|
-
q[i] = sycl::round(xi[i] / d);
|
|
1423
|
-
}
|
|
1424
|
-
}
|
|
1425
|
-
|
|
1426
|
-
*(TQ *)&y[ib].qs[iqs] = q;
|
|
1427
|
-
|
|
1428
|
-
if (iqs > 0) {
|
|
1429
|
-
return;
|
|
1430
|
-
}
|
|
1431
|
-
|
|
1432
|
-
reinterpret_cast<sycl::half &>(y[ib].ds.x()) = d;
|
|
1433
|
-
reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
|
|
1434
|
-
}
|
|
1435
|
-
|
|
1436
|
-
template <int ElementsPerWI>
|
|
1437
|
-
static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor,
|
|
1438
|
-
const int kx, const int kx_padded, const sycl::nd_item<1> & it) {
|
|
1439
|
-
/*
|
|
1440
|
-
Quantizes and reorders the resultant q8 tensor in a per row fashion
|
|
1441
|
-
Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
|
|
1442
|
-
*/
|
|
1443
|
-
|
|
1444
|
-
auto subgroup_id = it.get_group(0);
|
|
1445
|
-
auto wi_id = it.get_local_id(0);
|
|
1446
|
-
|
|
1447
|
-
const int num_blocks_per_row = kx / QK8_1;
|
|
1448
|
-
auto row = subgroup_id / num_blocks_per_row;
|
|
1449
|
-
auto col = subgroup_id % num_blocks_per_row;
|
|
1450
|
-
|
|
1451
|
-
auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1);
|
|
1452
|
-
auto col_offset = QK8_1 * col + wi_id * ElementsPerWI;
|
|
1453
|
-
|
|
1454
|
-
auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
|
|
1455
|
-
auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2));
|
|
1456
|
-
|
|
1457
|
-
sycl::vec<float, ElementsPerWI> wi_f32_vals;
|
|
1458
|
-
sycl::vec<int8_t, ElementsPerWI> quantized_values;
|
|
1459
|
-
|
|
1460
|
-
auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
|
|
1461
|
-
wi_f32_vals = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset);
|
|
1462
|
-
|
|
1463
|
-
float sum = 0.0f;
|
|
1464
|
-
float amax = 0.0f;
|
|
1465
|
-
|
|
1466
|
-
#pragma unroll(ElementsPerWI)
|
|
1467
|
-
for (int i = 0; i < ElementsPerWI; i++) {
|
|
1468
|
-
sum += wi_f32_vals[i];
|
|
1469
|
-
amax = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i]));
|
|
1470
|
-
quantized_values[i] = 0;
|
|
1471
|
-
}
|
|
1472
|
-
sum = sycl::reduce_over_group(it.get_group(), sum, sycl::plus<float>());
|
|
1473
|
-
amax = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum<float>());
|
|
1474
|
-
float d = amax == 0 ? 1 : amax / 127;
|
|
1475
|
-
|
|
1476
|
-
#pragma unroll(ElementsPerWI)
|
|
1477
|
-
for (int i = 0; i < ElementsPerWI; i++) {
|
|
1478
|
-
quantized_values[i] = sycl::round(wi_f32_vals[i] / d);
|
|
1479
|
-
}
|
|
1480
|
-
|
|
1481
|
-
d = amax == 0 ? 0 : d;
|
|
1482
|
-
|
|
1483
|
-
*reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values;
|
|
1484
|
-
if (wi_id == 0) {
|
|
1485
|
-
*ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum));
|
|
1486
|
-
}
|
|
1487
|
-
}
|
|
1488
|
-
|
|
1489
1402
|
static void mul_mat_p021_f16_f32(
|
|
1490
1403
|
const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
|
|
1491
1404
|
const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
|
|
@@ -1545,7 +1458,7 @@ static void mul_mat_p021_f16_f32(
|
|
|
1545
1458
|
|
|
1546
1459
|
static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
|
|
1547
1460
|
const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
|
|
1548
|
-
const int row_stride_x, const int channel_stride_x, const int channel_x_divisor,
|
|
1461
|
+
const int row_stride_x, const int channel_stride_x,const int channel_stride_y, const int channel_x_divisor,
|
|
1549
1462
|
const sycl::nd_item<3> &item_ct1) {
|
|
1550
1463
|
|
|
1551
1464
|
const sycl::half *x = (const sycl::half *)vx;
|
|
@@ -1556,7 +1469,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
|
|
|
1556
1469
|
item_ct1.get_local_id(0);
|
|
1557
1470
|
const int channel_x = channel / channel_x_divisor;
|
|
1558
1471
|
|
|
1559
|
-
const int nrows_y = ncols_x;
|
|
1560
1472
|
const int nrows_dst = nrows_x;
|
|
1561
1473
|
const int row_dst = row_x;
|
|
1562
1474
|
|
|
@@ -1575,7 +1487,7 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
|
|
|
1575
1487
|
const int row_y = col_x;
|
|
1576
1488
|
|
|
1577
1489
|
const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
|
|
1578
|
-
const int iy = channel*
|
|
1490
|
+
const int iy = channel * channel_stride_y + row_y;
|
|
1579
1491
|
|
|
1580
1492
|
const float xi =
|
|
1581
1493
|
sycl::vec<sycl::half, 1>(x[ix])
|
|
@@ -1624,60 +1536,70 @@ static inline void ggml_sycl_swap(T & a, T & b) {
|
|
|
1624
1536
|
template <ggml_sort_order order>
|
|
1625
1537
|
__dpct_inline__ static void
|
|
1626
1538
|
k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
|
|
1627
|
-
const sycl::nd_item<3> &item_ct1,
|
|
1539
|
+
const int tasks_per_thread, const sycl::nd_item<3> &item_ct1,
|
|
1540
|
+
uint8_t *dpct_local) {
|
|
1628
1541
|
// bitonic sort
|
|
1629
|
-
int
|
|
1542
|
+
int col_index = item_ct1.get_local_id(2);
|
|
1630
1543
|
int row = item_ct1.get_group(1);
|
|
1631
1544
|
|
|
1632
|
-
|
|
1633
|
-
|
|
1545
|
+
for (int i = 0; i < tasks_per_thread; i++) {
|
|
1546
|
+
int col = col_index * tasks_per_thread + i;
|
|
1547
|
+
if (col >= ncols_pad) {
|
|
1548
|
+
return;
|
|
1549
|
+
}
|
|
1634
1550
|
}
|
|
1635
1551
|
|
|
1636
1552
|
const float * x_row = x + row * ncols;
|
|
1637
1553
|
auto dst_row = (int *)dpct_local;
|
|
1638
1554
|
|
|
1639
1555
|
// initialize indices
|
|
1640
|
-
|
|
1556
|
+
for (int i=0;i<tasks_per_thread;i++){
|
|
1557
|
+
int col = col_index*tasks_per_thread+i;
|
|
1558
|
+
dst_row[col] = col;
|
|
1559
|
+
}
|
|
1641
1560
|
|
|
1642
1561
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
1643
1562
|
|
|
1644
1563
|
for (int k = 2; k <= ncols_pad; k *= 2) {
|
|
1645
1564
|
for (int j = k / 2; j > 0; j /= 2) {
|
|
1646
|
-
int
|
|
1647
|
-
|
|
1648
|
-
|
|
1649
|
-
|
|
1650
|
-
|
|
1651
|
-
|
|
1652
|
-
|
|
1653
|
-
|
|
1654
|
-
|
|
1655
|
-
|
|
1656
|
-
|
|
1657
|
-
|
|
1658
|
-
|
|
1659
|
-
|
|
1660
|
-
|
|
1661
|
-
|
|
1662
|
-
|
|
1565
|
+
for (int i = 0; i < tasks_per_thread; i++) {
|
|
1566
|
+
int col = col_index * tasks_per_thread + i;
|
|
1567
|
+
int ixj = col ^ j;
|
|
1568
|
+
if (ixj > col) {
|
|
1569
|
+
if ((col & k) == 0) {
|
|
1570
|
+
if (dst_row[col] >= ncols ||
|
|
1571
|
+
(dst_row[ixj] < ncols &&
|
|
1572
|
+
(order == GGML_SORT_ORDER_ASC
|
|
1573
|
+
? x_row[dst_row[col]] > x_row[dst_row[ixj]]
|
|
1574
|
+
: x_row[dst_row[col]] <
|
|
1575
|
+
x_row[dst_row[ixj]]))) {
|
|
1576
|
+
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
|
|
1577
|
+
}
|
|
1578
|
+
} else {
|
|
1579
|
+
if (dst_row[ixj] >= ncols ||
|
|
1580
|
+
(dst_row[col] < ncols &&
|
|
1581
|
+
(order == GGML_SORT_ORDER_ASC
|
|
1582
|
+
? x_row[dst_row[col]] < x_row[dst_row[ixj]]
|
|
1583
|
+
: x_row[dst_row[col]] >
|
|
1584
|
+
x_row[dst_row[ixj]]))) {
|
|
1585
|
+
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
|
|
1586
|
+
}
|
|
1663
1587
|
}
|
|
1664
1588
|
}
|
|
1589
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
1665
1590
|
}
|
|
1666
|
-
/*
|
|
1667
|
-
DPCT1118:1: SYCL group functions and algorithms must be encountered
|
|
1668
|
-
in converged control flow. You may need to adjust the code.
|
|
1669
|
-
*/
|
|
1670
|
-
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
1671
1591
|
}
|
|
1672
1592
|
}
|
|
1673
1593
|
|
|
1674
1594
|
// copy the result to dst without the padding
|
|
1675
|
-
|
|
1676
|
-
|
|
1595
|
+
for (int i = 0; i < tasks_per_thread; i++) {
|
|
1596
|
+
int col = col_index * tasks_per_thread + i;
|
|
1597
|
+
if (col < ncols) {
|
|
1598
|
+
dst[row * ncols + col] = dst_row[col];
|
|
1599
|
+
}
|
|
1677
1600
|
}
|
|
1678
1601
|
}
|
|
1679
1602
|
|
|
1680
|
-
|
|
1681
1603
|
static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
|
|
1682
1604
|
const sycl::nd_item<3> &item_ct1) {
|
|
1683
1605
|
const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
|
@@ -1695,7 +1617,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
|
|
|
1695
1617
|
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
|
|
1696
1618
|
}
|
|
1697
1619
|
|
|
1698
|
-
static void scale_f32(const float * x, float * dst, const float scale, const int k,
|
|
1620
|
+
static void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k,
|
|
1699
1621
|
const sycl::nd_item<3> &item_ct1) {
|
|
1700
1622
|
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
1701
1623
|
item_ct1.get_local_id(2);
|
|
@@ -1704,7 +1626,7 @@ static void scale_f32(const float * x, float * dst, const float scale, const int
|
|
|
1704
1626
|
return;
|
|
1705
1627
|
}
|
|
1706
1628
|
|
|
1707
|
-
dst[i] = scale * x[i];
|
|
1629
|
+
dst[i] = scale * x[i] + bias;
|
|
1708
1630
|
}
|
|
1709
1631
|
|
|
1710
1632
|
|
|
@@ -1770,32 +1692,6 @@ static void pool2d_nchw_kernel(
|
|
|
1770
1692
|
o_ptr[cur_oh * ow + cur_ow] = res;
|
|
1771
1693
|
}
|
|
1772
1694
|
|
|
1773
|
-
static void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded,
|
|
1774
|
-
bool reorder_q8_tensor, queue_ptr stream) {
|
|
1775
|
-
if (reorder_q8_tensor) {
|
|
1776
|
-
auto local_range = std::size_t(WARP_SIZE);
|
|
1777
|
-
auto num_quant_blocks = ky * (kx / QK8_1);
|
|
1778
|
-
auto global_range = num_quant_blocks * local_range;
|
|
1779
|
-
stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }),
|
|
1780
|
-
[=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
1781
|
-
quantize_and_reorder_q8_1<QK8_1 / WARP_SIZE>(x, vy, kx, kx_padded, it);
|
|
1782
|
-
});
|
|
1783
|
-
} else {
|
|
1784
|
-
const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
|
|
1785
|
-
const sycl::range<3> num_blocks(1, ky, block_num_x);
|
|
1786
|
-
int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
|
|
1787
|
-
static_assert(QK8_1 % WARP_SIZE == 0);
|
|
1788
|
-
const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
|
|
1789
|
-
{
|
|
1790
|
-
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
|
1791
|
-
|
|
1792
|
-
stream->parallel_for(sycl::nd_range<3>(num_blocks * block_size, block_size),
|
|
1793
|
-
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
1794
|
-
quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
|
|
1795
|
-
});
|
|
1796
|
-
}
|
|
1797
|
-
}
|
|
1798
|
-
}
|
|
1799
1695
|
|
|
1800
1696
|
static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
|
|
1801
1697
|
float *dst, const int ncols_x,
|
|
@@ -1822,7 +1718,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
|
|
|
1822
1718
|
static void ggml_mul_mat_vec_nc_f16_f32_sycl(
|
|
1823
1719
|
const void *vx, const float *y, float *dst, const int ncols_x,
|
|
1824
1720
|
const int nrows_x, const int row_stride_x, const int nchannels_x,
|
|
1825
|
-
const int nchannels_y, const int channel_stride_x, queue_ptr stream) {
|
|
1721
|
+
const int nchannels_y, const int channel_stride_x, const int channel_stride_y, queue_ptr stream) {
|
|
1826
1722
|
|
|
1827
1723
|
const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
|
|
1828
1724
|
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
|
@@ -1834,7 +1730,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
|
|
|
1834
1730
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
1835
1731
|
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
1836
1732
|
mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
|
|
1837
|
-
row_stride_x, channel_stride_x,
|
|
1733
|
+
row_stride_x, channel_stride_x, channel_stride_y,
|
|
1838
1734
|
nchannels_y / nchannels_x, item_ct1);
|
|
1839
1735
|
});
|
|
1840
1736
|
}
|
|
@@ -1842,7 +1738,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
|
|
|
1842
1738
|
|
|
1843
1739
|
|
|
1844
1740
|
|
|
1845
|
-
static void scale_f32_sycl(const float *x, float *dst, const float scale,
|
|
1741
|
+
static void scale_f32_sycl(const float *x, float *dst, const float scale, const float bias,
|
|
1846
1742
|
const int k, queue_ptr stream) {
|
|
1847
1743
|
const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;
|
|
1848
1744
|
stream->parallel_for(
|
|
@@ -1850,7 +1746,7 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale,
|
|
|
1850
1746
|
sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),
|
|
1851
1747
|
sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),
|
|
1852
1748
|
[=](sycl::nd_item<3> item_ct1) {
|
|
1853
|
-
scale_f32(x, dst, scale, k, item_ct1);
|
|
1749
|
+
scale_f32(x, dst, scale, bias, k, item_ct1);
|
|
1854
1750
|
});
|
|
1855
1751
|
}
|
|
1856
1752
|
|
|
@@ -1876,37 +1772,51 @@ static int next_power_of_2(int x) {
|
|
|
1876
1772
|
|
|
1877
1773
|
static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
|
|
1878
1774
|
const int nrows, ggml_sort_order order,
|
|
1879
|
-
queue_ptr stream) {
|
|
1775
|
+
queue_ptr stream, int device) {
|
|
1880
1776
|
// bitonic sort requires ncols to be power of 2
|
|
1881
1777
|
const int ncols_pad = next_power_of_2(ncols);
|
|
1882
1778
|
|
|
1883
|
-
|
|
1779
|
+
int nth = 1;
|
|
1780
|
+
int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
|
|
1781
|
+
while (nth < ncols_pad && nth < max_block_size)
|
|
1782
|
+
nth *= 2;
|
|
1783
|
+
if (nth > max_block_size)
|
|
1784
|
+
nth = max_block_size;
|
|
1785
|
+
|
|
1786
|
+
const int tasks_per_thread = ncols_pad / nth;
|
|
1787
|
+
|
|
1788
|
+
const sycl::range<3> block_dims(1, 1, nth);
|
|
1884
1789
|
const sycl::range<3> block_nums(1, nrows, 1);
|
|
1885
1790
|
const size_t shared_mem = ncols_pad * sizeof(int);
|
|
1791
|
+
GGML_ASSERT(shared_mem<=ggml_sycl_info().devices[device].smpbo);
|
|
1886
1792
|
|
|
1887
1793
|
if (order == GGML_SORT_ORDER_ASC) {
|
|
1888
|
-
|
|
1794
|
+
stream->submit([&](sycl::handler &cgh) {
|
|
1889
1795
|
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
|
|
1890
1796
|
sycl::range<1>(shared_mem), cgh);
|
|
1891
1797
|
|
|
1892
|
-
|
|
1893
|
-
|
|
1798
|
+
cgh.parallel_for(
|
|
1799
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
1800
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
1894
1801
|
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
|
|
1895
|
-
x, dst, ncols, ncols_pad, item_ct1,
|
|
1896
|
-
dpct_local_acc_ct1
|
|
1802
|
+
x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
|
|
1803
|
+
dpct_local_acc_ct1
|
|
1804
|
+
.get_multi_ptr<sycl::access::decorated::no>()
|
|
1897
1805
|
.get());
|
|
1898
1806
|
});
|
|
1899
1807
|
});
|
|
1900
1808
|
} else if (order == GGML_SORT_ORDER_DESC) {
|
|
1901
|
-
|
|
1809
|
+
stream->submit([&](sycl::handler &cgh) {
|
|
1902
1810
|
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
|
|
1903
1811
|
sycl::range<1>(shared_mem), cgh);
|
|
1904
1812
|
|
|
1905
|
-
|
|
1906
|
-
|
|
1813
|
+
cgh.parallel_for(
|
|
1814
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
1815
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
1907
1816
|
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
|
|
1908
|
-
x, dst, ncols, ncols_pad, item_ct1,
|
|
1909
|
-
dpct_local_acc_ct1
|
|
1817
|
+
x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
|
|
1818
|
+
dpct_local_acc_ct1
|
|
1819
|
+
.get_multi_ptr<sycl::access::decorated::no>()
|
|
1910
1820
|
.get());
|
|
1911
1821
|
});
|
|
1912
1822
|
});
|
|
@@ -1921,47 +1831,50 @@ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
|
|
|
1921
1831
|
const sycl::range<3> block_nums(1, nrows, 1);
|
|
1922
1832
|
const size_t shared_mem = 256 * sizeof(float);
|
|
1923
1833
|
|
|
1924
|
-
|
|
1834
|
+
stream->submit([&](sycl::handler &cgh) {
|
|
1925
1835
|
sycl::local_accessor<float, 1> shared_data(
|
|
1926
1836
|
sycl::range<1>(shared_mem/sizeof(float)), cgh);
|
|
1927
1837
|
sycl::local_accessor<int, 1> shared_indices(
|
|
1928
1838
|
sycl::range<1>(shared_mem/sizeof(float)), cgh);
|
|
1929
1839
|
|
|
1930
|
-
|
|
1931
|
-
|
|
1932
|
-
|
|
1933
|
-
|
|
1934
|
-
|
|
1935
|
-
|
|
1936
|
-
|
|
1937
|
-
|
|
1938
|
-
|
|
1939
|
-
|
|
1940
|
-
|
|
1941
|
-
|
|
1840
|
+
cgh.parallel_for(
|
|
1841
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
1842
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
1843
|
+
const int tid = item_ct1.get_local_id(2);
|
|
1844
|
+
const int row = item_ct1.get_global_id(1);
|
|
1845
|
+
|
|
1846
|
+
float max_val = -INFINITY;
|
|
1847
|
+
int max_idx = -1;
|
|
1848
|
+
|
|
1849
|
+
for (int col = tid; col < ncols; col += 256) {
|
|
1850
|
+
float val = x[row * ncols + col];
|
|
1851
|
+
if (val > max_val) {
|
|
1852
|
+
max_val = val;
|
|
1853
|
+
max_idx = col;
|
|
1854
|
+
}
|
|
1942
1855
|
}
|
|
1943
|
-
}
|
|
1944
1856
|
|
|
1945
|
-
|
|
1946
|
-
|
|
1947
|
-
|
|
1948
|
-
|
|
1949
|
-
|
|
1950
|
-
|
|
1951
|
-
|
|
1952
|
-
|
|
1953
|
-
|
|
1954
|
-
|
|
1955
|
-
|
|
1857
|
+
shared_data[tid] = max_val;
|
|
1858
|
+
shared_indices[tid] = max_idx;
|
|
1859
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
1860
|
+
|
|
1861
|
+
for (int stride = 256/2; stride > 0; stride >>= 1) {
|
|
1862
|
+
if (tid < stride) {
|
|
1863
|
+
float val1 = shared_data[tid];
|
|
1864
|
+
float val2 = shared_data[tid + stride];
|
|
1865
|
+
if (val2 > val1) {
|
|
1866
|
+
shared_data[tid] = val2;
|
|
1867
|
+
shared_indices[tid] = shared_indices[tid + stride];
|
|
1868
|
+
}
|
|
1956
1869
|
}
|
|
1870
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
1957
1871
|
}
|
|
1958
|
-
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
1959
|
-
}
|
|
1960
1872
|
|
|
1961
|
-
|
|
1962
|
-
|
|
1963
|
-
|
|
1964
|
-
|
|
1873
|
+
|
|
1874
|
+
if (tid == 0) {
|
|
1875
|
+
dst[row] = shared_indices[0];
|
|
1876
|
+
}
|
|
1877
|
+
});
|
|
1965
1878
|
});
|
|
1966
1879
|
}
|
|
1967
1880
|
static void diag_mask_inf_f32_sycl(const float *x, float *dst,
|
|
@@ -2123,8 +2036,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
|
|
2123
2036
|
|
|
2124
2037
|
#if GGML_SYCL_DNNL
|
|
2125
2038
|
if (!g_ggml_sycl_disable_dnn) {
|
|
2126
|
-
|
|
2127
|
-
|
|
2039
|
+
DnnlGemmWrapper::row_gemm(ctx,row_diff, src1_ncols , ne10, src0_ptr,
|
|
2040
|
+
DnnlGemmWrapper::to_dt<sycl::half>(), src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
|
|
2128
2041
|
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
|
|
2129
2042
|
}
|
|
2130
2043
|
else
|
|
@@ -2170,8 +2083,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
|
|
2170
2083
|
|
|
2171
2084
|
#if GGML_SYCL_DNNL
|
|
2172
2085
|
if (!g_ggml_sycl_disable_dnn) {
|
|
2173
|
-
DnnlGemmWrapper::row_gemm(ctx,
|
|
2174
|
-
DnnlGemmWrapper::to_dt<float>(),
|
|
2086
|
+
DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i,
|
|
2087
|
+
DnnlGemmWrapper::to_dt<float>(), src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
|
|
2175
2088
|
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
|
|
2176
2089
|
}
|
|
2177
2090
|
else
|
|
@@ -2261,6 +2174,30 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
|
|
|
2261
2174
|
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
|
|
2262
2175
|
}
|
|
2263
2176
|
|
|
2177
|
+
inline void ggml_sycl_op_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
2178
|
+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
|
2179
|
+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
2180
|
+
|
|
2181
|
+
dpct::queue_ptr main_stream = ctx.stream();
|
|
2182
|
+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
|
2183
|
+
|
|
2184
|
+
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
|
2185
|
+
float * dst_dd = static_cast<float *>(dst->data);
|
|
2186
|
+
|
|
2187
|
+
const int64_t ncols = dst->src[0]->ne[0];
|
|
2188
|
+
const int64_t nrows = ggml_nrows(dst->src[0]);
|
|
2189
|
+
|
|
2190
|
+
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
|
|
2191
|
+
|
|
2192
|
+
main_stream->parallel_for(
|
|
2193
|
+
sycl::range<1>(nrows),
|
|
2194
|
+
[=](sycl::id<1> row) {
|
|
2195
|
+
dst_dd[row] /= ncols;
|
|
2196
|
+
}
|
|
2197
|
+
);
|
|
2198
|
+
}
|
|
2199
|
+
|
|
2200
|
+
|
|
2264
2201
|
inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
2265
2202
|
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
|
2266
2203
|
GGML_ASSERT(dst->type == GGML_TYPE_I32);
|
|
@@ -2275,7 +2212,8 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor *
|
|
|
2275
2212
|
|
|
2276
2213
|
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
|
2277
2214
|
|
|
2278
|
-
argsort_f32_i32_sycl(src0_dd, (int *)
|
|
2215
|
+
argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order,
|
|
2216
|
+
main_stream, ctx.device);
|
|
2279
2217
|
}
|
|
2280
2218
|
|
|
2281
2219
|
inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
@@ -2319,9 +2257,11 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
|
|
2319
2257
|
float * dst_dd = static_cast<float *>(dst->data);
|
|
2320
2258
|
|
|
2321
2259
|
float scale;
|
|
2322
|
-
|
|
2260
|
+
float bias;
|
|
2261
|
+
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
|
2262
|
+
memcpy(&bias, (float *) dst->op_params + 1, sizeof(float));
|
|
2323
2263
|
|
|
2324
|
-
scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
|
|
2264
|
+
scale_f32_sycl(src0_dd, dst_dd, scale, bias, ggml_nelements(dst->src[0]), main_stream);
|
|
2325
2265
|
/*
|
|
2326
2266
|
DPCT1010:87: SYCL uses exceptions to report errors and does not use the
|
|
2327
2267
|
error codes. The call was replaced with 0. You need to rewrite this code.
|
|
@@ -2370,10 +2310,10 @@ static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
|
|
|
2370
2310
|
peer_access_enabled = enable_peer_access;
|
|
2371
2311
|
}
|
|
2372
2312
|
|
|
2313
|
+
template <template <int> typename quantize_f>
|
|
2373
2314
|
static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
|
2374
2315
|
const ggml_tensor *src1, ggml_tensor *dst,
|
|
2375
|
-
ggml_sycl_op_mul_mat_t op
|
|
2376
|
-
const bool convert_src1_to_q8_1) try {
|
|
2316
|
+
ggml_sycl_op_mul_mat_t op) try {
|
|
2377
2317
|
|
|
2378
2318
|
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
|
|
2379
2319
|
|
|
@@ -2468,6 +2408,8 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
|
|
|
2468
2408
|
}
|
|
2469
2409
|
}
|
|
2470
2410
|
|
|
2411
|
+
constexpr bool quantize_enabled = !std::is_same_v<quantize_f<QK8_1 / WARP_SIZE>,
|
|
2412
|
+
no_quantize_q8_1<QK8_1 / WARP_SIZE>>;
|
|
2471
2413
|
for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
|
|
2472
2414
|
if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
|
|
2473
2415
|
continue;
|
|
@@ -2493,20 +2435,19 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
|
|
|
2493
2435
|
dev[i].src1_ddf = dev[i].src1_ddf_alloc.alloc(ctx.pool(i), ggml_nelements(src1));
|
|
2494
2436
|
}
|
|
2495
2437
|
|
|
2496
|
-
if (
|
|
2438
|
+
if constexpr(quantize_enabled) {
|
|
2497
2439
|
dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
|
|
2498
2440
|
|
|
2499
2441
|
if (src1_on_device && src1_is_contiguous) {
|
|
2500
|
-
bool reorder_q8_tensor = src0->extra && ((ggml_tensor_extra_gpu *)src0->extra)->optimized_feature.reorder;
|
|
2501
2442
|
scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
|
|
2502
2443
|
/*num_src=*/2, " : converting src1 to Q8_1");
|
|
2503
|
-
|
|
2504
|
-
|
|
2505
|
-
|
|
2506
|
-
|
|
2507
|
-
|
|
2508
|
-
|
|
2509
|
-
|
|
2444
|
+
try {
|
|
2445
|
+
quantize_row_q8_1_sycl<quantize_f>(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
|
|
2446
|
+
} catch (sycl::exception const &exc) {
|
|
2447
|
+
std::cerr << "Quantize_row_q8_1_sycl error" << exc.what() << "Exception caught at file:" << __FILE__
|
|
2448
|
+
<< ", line:" << __LINE__ << std::endl;
|
|
2449
|
+
std::exit(1);
|
|
2450
|
+
}
|
|
2510
2451
|
}
|
|
2511
2452
|
}
|
|
2512
2453
|
|
|
@@ -2522,11 +2463,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
|
|
|
2522
2463
|
// here an event is recorded that signals that the main device has finished calculating the input data
|
|
2523
2464
|
if (split && used_devices > 1) {
|
|
2524
2465
|
ggml_sycl_set_device(ctx.device);
|
|
2525
|
-
/*
|
|
2526
|
-
DPCT1024:91: The original code returned the error code that was further
|
|
2527
|
-
consumed by the program logic. This original code was replaced with 0.
|
|
2528
|
-
You may need to rewrite the program logic consuming the error code.
|
|
2529
|
-
*/
|
|
2530
2466
|
SYCL_CHECK(CHECK_TRY_ERROR(
|
|
2531
2467
|
*src0_extra->events[ctx.device][0] =
|
|
2532
2468
|
ctx.stream()->ext_oneapi_submit_barrier()));
|
|
@@ -2550,11 +2486,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
|
|
|
2550
2486
|
|
|
2551
2487
|
// wait for main GPU data if necessary
|
|
2552
2488
|
if (split && (i != ctx.device || is != 0)) {
|
|
2553
|
-
/*
|
|
2554
|
-
DPCT1009:163: SYCL uses exceptions to report errors and does not
|
|
2555
|
-
use the error codes. The original code was commented out and a
|
|
2556
|
-
warning string was inserted. You need to rewrite this code.
|
|
2557
|
-
*/
|
|
2558
2489
|
SYCL_CHECK(CHECK_TRY_ERROR(stream->ext_oneapi_submit_barrier(
|
|
2559
2490
|
{*src0_extra->events[ctx.device][0]})));
|
|
2560
2491
|
}
|
|
@@ -2580,39 +2511,42 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
|
|
|
2580
2511
|
// copy src0, src1 to device if necessary
|
|
2581
2512
|
if (src1_is_contiguous) {
|
|
2582
2513
|
if (i != ctx.device) {
|
|
2583
|
-
if (
|
|
2514
|
+
if constexpr (quantize_enabled) {
|
|
2584
2515
|
char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
|
|
2585
|
-
|
|
2586
|
-
|
|
2587
|
-
|
|
2588
|
-
|
|
2516
|
+
SYCL_CHECK(
|
|
2517
|
+
CHECK_TRY_ERROR(stream
|
|
2518
|
+
->memcpy(src1_ddq_i, src1_ddq_i_source,
|
|
2519
|
+
src1_ncols * src1_padded_col_size * q8_1_ts / q8_1_bs)
|
|
2520
|
+
.wait()));
|
|
2589
2521
|
} else {
|
|
2590
|
-
|
|
2591
2522
|
float * src1_ddf_i_source = (float *) src1_extra->data_device[ctx.device];
|
|
2592
|
-
src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
|
|
2523
|
+
src1_ddf_i_source += (i0 * ne11 + src1_col_0) * ne10;
|
|
2593
2524
|
|
|
2594
|
-
SYCL_CHECK(
|
|
2595
|
-
src1_ddf_i, src1_ddf_i_source,
|
|
2596
|
-
|
|
2525
|
+
SYCL_CHECK(
|
|
2526
|
+
CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream, src1_ddf_i, src1_ddf_i_source,
|
|
2527
|
+
src1_ncols * ne10 * sizeof(float))));
|
|
2597
2528
|
}
|
|
2598
2529
|
}
|
|
2599
|
-
} else if (src1_on_device && !src1_is_contiguous) {
|
|
2600
|
-
SYCL_CHECK(ggml_sycl_cpy_tensor_2d(
|
|
2601
|
-
src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
|
|
2602
2530
|
} else {
|
|
2603
|
-
|
|
2604
|
-
|
|
2531
|
+
if (src1_on_device) {
|
|
2532
|
+
SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, src1_col_0,
|
|
2533
|
+
src1_col_0 + src1_ncols, stream));
|
|
2534
|
+
} else {
|
|
2535
|
+
GGML_ABORT("src1 is non-contiguous and not on device");
|
|
2536
|
+
}
|
|
2605
2537
|
|
|
2606
|
-
|
|
2607
|
-
|
|
2608
|
-
|
|
2609
|
-
|
|
2610
|
-
|
|
2611
|
-
|
|
2612
|
-
|
|
2613
|
-
|
|
2614
|
-
|
|
2615
|
-
|
|
2538
|
+
if constexpr (quantize_enabled) {
|
|
2539
|
+
scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
|
|
2540
|
+
/*num_src=*/2, " : converting src1 to Q8_1");
|
|
2541
|
+
try {
|
|
2542
|
+
quantize_row_q8_1_sycl<quantize_q8_1>(src1_ddf_i, src1_ddq_i, ne10, src1_ncols,
|
|
2543
|
+
src1_padded_col_size, stream);
|
|
2544
|
+
} catch (const sycl::exception & exc) {
|
|
2545
|
+
std::cerr << "Quantize_row_q8_1_sycl error" << exc.what()
|
|
2546
|
+
<< "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
|
|
2547
|
+
std::exit(1);
|
|
2548
|
+
}
|
|
2549
|
+
}
|
|
2616
2550
|
}
|
|
2617
2551
|
|
|
2618
2552
|
if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
|
|
@@ -2624,12 +2558,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
|
|
|
2624
2558
|
// do the computation
|
|
2625
2559
|
SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
|
|
2626
2560
|
dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));
|
|
2627
|
-
/*
|
|
2628
|
-
DPCT1010:93: SYCL uses exceptions to report errors and does not
|
|
2629
|
-
use the error codes. The call was replaced with 0. You need to
|
|
2630
|
-
rewrite this code.
|
|
2631
|
-
*/
|
|
2632
|
-
SYCL_CHECK(0);
|
|
2633
2561
|
|
|
2634
2562
|
// copy dst to host or other device if necessary
|
|
2635
2563
|
if (!dst_on_device) {
|
|
@@ -2660,12 +2588,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
|
|
|
2660
2588
|
|
|
2661
2589
|
// add event for the main device to wait on until other device is done
|
|
2662
2590
|
if (split && (i != ctx.device || is != 0)) {
|
|
2663
|
-
/*
|
|
2664
|
-
DPCT1024:94: The original code returned the error code that
|
|
2665
|
-
was further consumed by the program logic. This original
|
|
2666
|
-
code was replaced with 0. You may need to rewrite the
|
|
2667
|
-
program logic consuming the error code.
|
|
2668
|
-
*/
|
|
2669
2591
|
SYCL_CHECK(CHECK_TRY_ERROR(
|
|
2670
2592
|
*src0_extra->events[i][is] =
|
|
2671
2593
|
stream->ext_oneapi_submit_barrier()));
|
|
@@ -2698,6 +2620,10 @@ catch (sycl::exception const &exc) {
|
|
|
2698
2620
|
std::exit(1);
|
|
2699
2621
|
}
|
|
2700
2622
|
|
|
2623
|
+
static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
2624
|
+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
|
2625
|
+
ggml_sycl_op_repeat_back(ctx, dst);
|
|
2626
|
+
}
|
|
2701
2627
|
|
|
2702
2628
|
static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
2703
2629
|
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
|
@@ -2714,6 +2640,11 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
|
|
2714
2640
|
ggml_sycl_op_rms_norm(ctx, dst);
|
|
2715
2641
|
}
|
|
2716
2642
|
|
|
2643
|
+
static void ggml_sycl_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
2644
|
+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
|
2645
|
+
ggml_sycl_op_rms_norm_back(ctx, dst);
|
|
2646
|
+
}
|
|
2647
|
+
|
|
2717
2648
|
static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
2718
2649
|
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
|
2719
2650
|
ggml_sycl_op_l2_norm(ctx, dst);
|
|
@@ -2764,6 +2695,8 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
|
|
|
2764
2695
|
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
|
|
2765
2696
|
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
|
2766
2697
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
2698
|
+
GGML_ASSERT(src1->ne[1] == 1);
|
|
2699
|
+
GGML_ASSERT(src1->ne[3] == 1);
|
|
2767
2700
|
|
|
2768
2701
|
const int64_t ne00 = src0->ne[0];
|
|
2769
2702
|
const int64_t ne01 = src0->ne[1];
|
|
@@ -2773,6 +2706,7 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
|
|
|
2773
2706
|
const int64_t nb02 = src0->nb[2];
|
|
2774
2707
|
|
|
2775
2708
|
const int64_t ne12 = src1->ne[2];
|
|
2709
|
+
const int64_t nb11 = src1->nb[1];
|
|
2776
2710
|
|
|
2777
2711
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
|
2778
2712
|
queue_ptr main_stream = ctx.stream();
|
|
@@ -2783,8 +2717,9 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
|
|
|
2783
2717
|
|
|
2784
2718
|
const int64_t row_stride_x = nb01 / sizeof(sycl::half);
|
|
2785
2719
|
const int64_t channel_stride_x = nb02 / sizeof(sycl::half);
|
|
2720
|
+
const int64_t channel_stride_y = nb11 / sizeof(float);
|
|
2786
2721
|
|
|
2787
|
-
ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
|
|
2722
|
+
ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x,channel_stride_y, main_stream);
|
|
2788
2723
|
}
|
|
2789
2724
|
catch (sycl::exception const &exc) {
|
|
2790
2725
|
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
|
@@ -2838,8 +2773,11 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
|
|
2838
2773
|
float * dst_ddf = static_cast<float *>(dst->data);
|
|
2839
2774
|
|
|
2840
2775
|
const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
|
|
2776
|
+
const size_t type_size_src0 = ggml_type_size(src0->type);
|
|
2841
2777
|
const size_t type_size_src1 = ggml_type_size(src1->type);
|
|
2842
|
-
|
|
2778
|
+
|
|
2779
|
+
bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
|
|
2780
|
+
bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
|
|
2843
2781
|
|
|
2844
2782
|
// SRC1 strides
|
|
2845
2783
|
int64_t s11 = nb11 / type_size_src1;
|
|
@@ -2851,16 +2789,47 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
|
|
2851
2789
|
if (src1->type != GGML_TYPE_F16) {
|
|
2852
2790
|
scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2,
|
|
2853
2791
|
" : converting src1 to fp16");
|
|
2854
|
-
|
|
2855
|
-
|
|
2792
|
+
|
|
2793
|
+
// iterate tensor dims and find the slowest moving dim and stride
|
|
2794
|
+
int last_dim=0;
|
|
2795
|
+
int last_str=0;
|
|
2796
|
+
size_t largest_str=0;
|
|
2797
|
+
for(int i = 0; i< 4; i++){
|
|
2798
|
+
// last stride is always the largest
|
|
2799
|
+
if(src1->nb[i] == largest_str){
|
|
2800
|
+
if(src1->ne[last_dim] == 1){
|
|
2801
|
+
last_str = i;
|
|
2802
|
+
last_dim = i;
|
|
2803
|
+
}
|
|
2804
|
+
}
|
|
2805
|
+
if(src1->nb[i] > largest_str){
|
|
2806
|
+
largest_str = src1->nb[i];
|
|
2807
|
+
last_str = i;
|
|
2808
|
+
last_dim = i;
|
|
2809
|
+
}
|
|
2810
|
+
|
|
2811
|
+
}
|
|
2812
|
+
#if GGML_SYCL_DNNL
|
|
2813
|
+
// oneDNN handles strided data and does not need overhead of get_to_fp16_nc_sycl
|
|
2814
|
+
const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
|
|
2815
|
+
src1_f16_alloc.alloc(ne_src1);
|
|
2816
|
+
const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
|
|
2817
|
+
GGML_ASSERT(to_fp16_sycl != nullptr);
|
|
2818
|
+
to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue);
|
|
2819
|
+
# else
|
|
2856
2820
|
const int64_t ne_src1 = ggml_nelements(src1);
|
|
2857
2821
|
src1_f16_alloc.alloc(ne_src1);
|
|
2822
|
+
const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
|
|
2823
|
+
GGML_ASSERT(to_fp16_nc_sycl != nullptr);
|
|
2858
2824
|
to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
|
|
2825
|
+
#endif
|
|
2859
2826
|
|
|
2860
2827
|
src1_f16 = src1_f16_alloc.get();
|
|
2861
2828
|
s11 = ne10;
|
|
2862
2829
|
s12 = ne11 * s11;
|
|
2863
2830
|
s13 = ne12 * s12;
|
|
2831
|
+
|
|
2832
|
+
is_src1_cont_2 = true;
|
|
2864
2833
|
}
|
|
2865
2834
|
|
|
2866
2835
|
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
|
|
@@ -2889,48 +2858,115 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
|
|
2889
2858
|
|
|
2890
2859
|
#if GGML_SYCL_DNNL
|
|
2891
2860
|
if (!g_ggml_sycl_disable_dnn) {
|
|
2892
|
-
|
|
2893
|
-
|
|
2894
|
-
|
|
2895
|
-
|
|
2896
|
-
|
|
2897
|
-
|
|
2898
|
-
|
|
2899
|
-
|
|
2900
|
-
|
|
2901
|
-
|
|
2902
|
-
|
|
2903
|
-
|
|
2904
|
-
|
|
2905
|
-
|
|
2906
|
-
|
|
2907
|
-
|
|
2908
|
-
|
|
2909
|
-
|
|
2910
|
-
|
|
2861
|
+
int64_t str_a0 = nb00 / type_size_src0;
|
|
2862
|
+
int64_t str_a1 = nb01 / type_size_src0;
|
|
2863
|
+
int64_t str_a2 = nb02 / type_size_src0;
|
|
2864
|
+
|
|
2865
|
+
int64_t str_b0 = nb10 / type_size_src1;
|
|
2866
|
+
int64_t str_b1 = nb11 / type_size_src1;
|
|
2867
|
+
int64_t str_b2 = nb12 / type_size_src1;
|
|
2868
|
+
|
|
2869
|
+
auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0,
|
|
2870
|
+
const sycl::half *src1, float *dst,
|
|
2871
|
+
int64_t a0, int64_t a1, int64_t batcha,
|
|
2872
|
+
int64_t /*b0*/, int64_t b1, int64_t batchb,
|
|
2873
|
+
int64_t sa0, int64_t sa1, int64_t sa2,
|
|
2874
|
+
int64_t sb0, int64_t sb1, int64_t sb2,
|
|
2875
|
+
int64_t sd2) {
|
|
2876
|
+
bool supported_broadcast = batchb == batcha ? true
|
|
2877
|
+
: batchb == 1 || batcha == 1 ? true
|
|
2878
|
+
: false;
|
|
2879
|
+
if (supported_broadcast) {
|
|
2880
|
+
DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0,
|
|
2881
|
+
DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2, src1,
|
|
2882
|
+
DnnlGemmWrapper::to_dt<sycl::half>(), sb0, sb1, sb2, dst,
|
|
2883
|
+
DnnlGemmWrapper::to_dt<float>(), queue, batcha, batchb);
|
|
2884
|
+
} else {
|
|
2885
|
+
// iterate over batches from smaller set of matrices (matrix 0)
|
|
2886
|
+
int64_t batches0 = batcha;
|
|
2887
|
+
int64_t batches1 = batchb;
|
|
2888
|
+
|
|
2889
|
+
if (batches0 > batches1) {
|
|
2890
|
+
int64_t num_mul_mats = batches1;
|
|
2891
|
+
int64_t sub_batch = batches0 / num_mul_mats;
|
|
2892
|
+
// src0 is batched and bigger, shift and multiply with src1
|
|
2893
|
+
for (int64_t i0 = 0; i0 < num_mul_mats; i0++) {
|
|
2894
|
+
const sycl::half *src0_shifted = src0 + (sa2 * i0 * sub_batch);
|
|
2895
|
+
const sycl::half *src1_shifted = src1 + (sb2 * i0);
|
|
2896
|
+
float *dst_shifted = dst + (sd2 * i0 * sub_batch);
|
|
2897
|
+
DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
|
|
2898
|
+
DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
|
|
2899
|
+
src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
|
|
2900
|
+
sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
|
|
2901
|
+
queue, sub_batch, 1);
|
|
2902
|
+
}
|
|
2903
|
+
} else {
|
|
2904
|
+
int64_t num_mul_mats = batches0;
|
|
2905
|
+
int64_t sub_batch = batches1 / num_mul_mats;
|
|
2906
|
+
// src1 is batched and bigger, shift and multiply with src0
|
|
2907
|
+
for (int64_t i1 = 0; i1 < num_mul_mats; i1++) {
|
|
2908
|
+
const sycl::half *src0_shifted = src0 + (sa2 * i1);
|
|
2909
|
+
const sycl::half *src1_shifted = src1 + (sb2 * i1 * sub_batch);
|
|
2910
|
+
float *dst_shifted = dst + (sd2 * i1 * sub_batch);
|
|
2911
|
+
DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
|
|
2912
|
+
DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
|
|
2913
|
+
src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
|
|
2914
|
+
sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
|
|
2915
|
+
queue, 1, sub_batch);
|
|
2916
|
+
}
|
|
2917
|
+
}
|
|
2911
2918
|
}
|
|
2912
|
-
}
|
|
2913
|
-
|
|
2914
|
-
|
|
2915
|
-
|
|
2916
|
-
|
|
2917
|
-
|
|
2918
|
-
|
|
2919
|
-
|
|
2920
|
-
|
|
2919
|
+
};
|
|
2920
|
+
|
|
2921
|
+
const bool cont_batches_dim2_a = nb02 * ne02 == nb03;
|
|
2922
|
+
const bool cont_batches_dim2_b = nb12 * ne12 == nb13;
|
|
2923
|
+
const bool cont_batches_dim3_a = ne02 == 1 && nb02 * ne01 == nb03;
|
|
2924
|
+
const bool cont_batches_dim3_b = ne12 == 1 && nb12 * ne11 == nb13;
|
|
2925
|
+
if (cont_batches_dim2_a && cont_batches_dim2_b) {
|
|
2926
|
+
// A batch is considered contiguous if the dimension 2 is not strided
|
|
2927
|
+
int64_t batches0 = ne02 * ne03;
|
|
2928
|
+
int64_t batches1 = ne12 * ne13;
|
|
2929
|
+
launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
|
|
2930
|
+
ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1,
|
|
2931
|
+
str_b2, nb2 / sizeof(float));
|
|
2932
|
+
} else if (cont_batches_dim3_a && cont_batches_dim3_b) {
|
|
2933
|
+
// This case is similar to the one above with the difference that only the batch in dimension 3 is used and the dimension 2 is of size 1.
|
|
2934
|
+
int64_t batches0 = ne02 * ne03;
|
|
2935
|
+
int64_t batches1 = ne12 * ne13;
|
|
2936
|
+
int64_t str_a3 = nb03 / type_size_src0;
|
|
2937
|
+
int64_t str_b3 = nb13 / type_size_src1;
|
|
2938
|
+
launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
|
|
2939
|
+
ne10, ne11, batches1, str_a0, str_a1, str_a3, str_b0, str_b1,
|
|
2940
|
+
str_b3, nb2 / sizeof(float));
|
|
2941
|
+
} else {
|
|
2942
|
+
for (int64_t b_a = 0; b_a < ne03; b_a++) {
|
|
2943
|
+
const sycl::half *src0_f16_shifted
|
|
2944
|
+
= src0_f16 + (nb03 * b_a / type_size_src0);
|
|
2945
|
+
const sycl::half *src1_f16_shifted
|
|
2946
|
+
= src1_f16 + (nb13 * b_a / type_size_src1);
|
|
2947
|
+
float *dst_shifted = dst_ddf + (nb3 * b_a / sizeof(float));
|
|
2948
|
+
int64_t batches0 = ne02;
|
|
2949
|
+
int64_t batches1 = ne12;
|
|
2950
|
+
launch_gemm_for_batches(src0_f16_shifted, src1_f16_shifted, dst_shifted,
|
|
2951
|
+
ne00, ne01, batches0, ne10, ne11, batches1, str_a0, str_a1,
|
|
2952
|
+
str_a2, str_b0, str_b1, str_b2, nb2 / sizeof(float));
|
|
2921
2953
|
}
|
|
2922
2954
|
}
|
|
2923
|
-
|
|
2955
|
+
|
|
2924
2956
|
}
|
|
2925
2957
|
else
|
|
2926
2958
|
#endif
|
|
2927
2959
|
{
|
|
2928
|
-
if (r2 == 1 && r3 == 1 &&
|
|
2960
|
+
if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
|
|
2961
|
+
// with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
|
|
2962
|
+
const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
|
|
2963
|
+
const int64_t smb = ne12 == 1 ? s13 : s12;
|
|
2964
|
+
|
|
2929
2965
|
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
|
|
2930
2966
|
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
|
|
2931
2967
|
oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
|
|
2932
|
-
src0_f16, dpct::library_data_t::real_half, nb01 / nb00,
|
|
2933
|
-
src1_f16, dpct::library_data_t::real_half, s11,
|
|
2968
|
+
src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma,
|
|
2969
|
+
src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf,
|
|
2934
2970
|
mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
|
|
2935
2971
|
} else {
|
|
2936
2972
|
const int ne23 = ne12 * ne13;
|
|
@@ -2945,7 +2981,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
|
|
2945
2981
|
void ** ptrs_dst_get = ptrs_dst.get();
|
|
2946
2982
|
size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
|
|
2947
2983
|
size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
|
|
2948
|
-
|
|
2984
|
+
cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
|
2949
2985
|
k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
|
|
2950
2986
|
nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
|
|
2951
2987
|
});
|
|
@@ -3026,19 +3062,51 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
|
|
|
3026
3062
|
}
|
|
3027
3063
|
}
|
|
3028
3064
|
|
|
3065
|
+
// Helper functions to unify device memory allocation for both async and sync paths
|
|
3066
|
+
static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size) {
|
|
3067
|
+
bool use_async = g_ggml_sycl_use_async_mem_op;
|
|
3068
|
+
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
|
|
3069
|
+
if (use_async) {
|
|
3070
|
+
return syclex::async_malloc(*stream, sycl::usm::alloc::device, size);
|
|
3071
|
+
}
|
|
3072
|
+
#else
|
|
3073
|
+
// If async allocation extension is not available, use_async should always be false.
|
|
3074
|
+
GGML_ASSERT(!use_async);
|
|
3075
|
+
#endif
|
|
3076
|
+
return sycl::malloc(size, *stream, sycl::usm::alloc::device);
|
|
3077
|
+
}
|
|
3078
|
+
|
|
3079
|
+
static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) {
|
|
3080
|
+
bool use_async = g_ggml_sycl_use_async_mem_op;
|
|
3081
|
+
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
|
|
3082
|
+
if (use_async) {
|
|
3083
|
+
syclex::async_free(*stream, ptr);
|
|
3084
|
+
return;
|
|
3085
|
+
}
|
|
3086
|
+
#else
|
|
3087
|
+
// If async allocation extension is not available, use_async should always be false.
|
|
3088
|
+
GGML_ASSERT(!use_async);
|
|
3089
|
+
#endif
|
|
3090
|
+
sycl::free(ptr, *stream);
|
|
3091
|
+
}
|
|
3092
|
+
|
|
3029
3093
|
static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
|
|
3030
3094
|
dpct::queue_ptr stream) {
|
|
3031
|
-
|
|
3032
|
-
|
|
3033
|
-
|
|
3034
|
-
|
|
3095
|
+
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
|
|
3096
|
+
|
|
3097
|
+
sycl::event copy_event;
|
|
3098
|
+
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
|
|
3099
|
+
if (!g_ggml_sycl_use_async_mem_op) {
|
|
3100
|
+
copy_event.wait();
|
|
3101
|
+
}
|
|
3102
|
+
|
|
3035
3103
|
GGML_ASSERT((size % sizeof(block_q4_0) == 0));
|
|
3036
3104
|
GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
|
|
3037
3105
|
int offset_blks = offset / sizeof(block_q4_0);
|
|
3038
3106
|
auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
|
|
3039
3107
|
auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
|
|
3040
3108
|
|
|
3041
|
-
stream->parallel_for(
|
|
3109
|
+
auto reorder_event = stream->parallel_for(
|
|
3042
3110
|
size / sizeof(block_q4_0),
|
|
3043
3111
|
[=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
3044
3112
|
const block_q4_0* x = (const block_q4_0*)tmp_buf;
|
|
@@ -3049,9 +3117,11 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr
|
|
|
3049
3117
|
*(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
|
|
3050
3118
|
}
|
|
3051
3119
|
*(d_ptr + ib) = x[ib].d;
|
|
3052
|
-
})
|
|
3053
|
-
|
|
3054
|
-
|
|
3120
|
+
});
|
|
3121
|
+
if (!g_ggml_sycl_use_async_mem_op) {
|
|
3122
|
+
reorder_event.wait_and_throw();
|
|
3123
|
+
}
|
|
3124
|
+
sycl_ext_free(stream, tmp_buf);
|
|
3055
3125
|
}
|
|
3056
3126
|
|
|
3057
3127
|
static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
|
|
@@ -3060,14 +3130,19 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
|
|
|
3060
3130
|
|
|
3061
3131
|
const int nblocks = size / sizeof(block_q4_K);
|
|
3062
3132
|
|
|
3063
|
-
|
|
3064
|
-
|
|
3133
|
+
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
|
|
3134
|
+
|
|
3135
|
+
sycl::event copy_event;
|
|
3136
|
+
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
|
|
3137
|
+
if (!g_ggml_sycl_use_async_mem_op) {
|
|
3138
|
+
copy_event.wait();
|
|
3139
|
+
}
|
|
3065
3140
|
|
|
3066
3141
|
auto * qs_ptr = data_device;
|
|
3067
3142
|
auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
|
|
3068
3143
|
auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
|
|
3069
3144
|
|
|
3070
|
-
stream->parallel_for(nblocks, [=](auto i) {
|
|
3145
|
+
auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
|
|
3071
3146
|
const block_q4_K * x = (const block_q4_K *) tmp_buf;
|
|
3072
3147
|
const int ib = i;
|
|
3073
3148
|
|
|
@@ -3080,9 +3155,11 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
|
|
|
3080
3155
|
}
|
|
3081
3156
|
|
|
3082
3157
|
dm_ptr[ib] = x[ib].dm;
|
|
3083
|
-
})
|
|
3084
|
-
|
|
3085
|
-
|
|
3158
|
+
});
|
|
3159
|
+
if (!g_ggml_sycl_use_async_mem_op) {
|
|
3160
|
+
reorder_event.wait_and_throw();
|
|
3161
|
+
}
|
|
3162
|
+
sycl_ext_free(stream, tmp_buf);
|
|
3086
3163
|
}
|
|
3087
3164
|
|
|
3088
3165
|
static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
|
|
@@ -3091,42 +3168,46 @@ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d
|
|
|
3091
3168
|
|
|
3092
3169
|
const int nblocks = size / sizeof(block_q6_K);
|
|
3093
3170
|
|
|
3094
|
-
|
|
3095
|
-
|
|
3171
|
+
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
|
|
3172
|
+
|
|
3173
|
+
sycl::event copy_event;
|
|
3174
|
+
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
|
|
3175
|
+
if (!g_ggml_sycl_use_async_mem_op) {
|
|
3176
|
+
copy_event.wait();
|
|
3177
|
+
}
|
|
3096
3178
|
|
|
3097
3179
|
auto * ql_ptr = data_device;
|
|
3098
3180
|
auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks;
|
|
3099
3181
|
auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
|
|
3100
3182
|
sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
|
|
3101
3183
|
|
|
3102
|
-
stream
|
|
3103
|
-
|
|
3104
|
-
|
|
3105
|
-
const block_q6_K * x = (const block_q6_K *) tmp_buf;
|
|
3106
|
-
const int ib = i;
|
|
3107
|
-
|
|
3108
|
-
const uint8_t * ql = x[ib].ql;
|
|
3109
|
-
const uint8_t * qh = x[ib].qh;
|
|
3110
|
-
uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
|
|
3111
|
-
uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
|
|
3112
|
-
uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
|
|
3184
|
+
auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
|
|
3185
|
+
const block_q6_K * x = (const block_q6_K *) tmp_buf;
|
|
3186
|
+
const int ib = i;
|
|
3113
3187
|
|
|
3114
|
-
|
|
3115
|
-
|
|
3116
|
-
|
|
3117
|
-
|
|
3118
|
-
|
|
3119
|
-
}
|
|
3188
|
+
const uint8_t * ql = x[ib].ql;
|
|
3189
|
+
const uint8_t * qh = x[ib].qh;
|
|
3190
|
+
uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
|
|
3191
|
+
uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
|
|
3192
|
+
uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
|
|
3120
3193
|
|
|
3121
|
-
|
|
3122
|
-
|
|
3123
|
-
|
|
3194
|
+
for (int j = 0; j < QK_K / 2; ++j) {
|
|
3195
|
+
base_ql_ptr[j] = ql[j];
|
|
3196
|
+
}
|
|
3197
|
+
for (int j = 0; j < QK_K / 4; ++j) {
|
|
3198
|
+
base_qh_ptr[j] = qh[j];
|
|
3199
|
+
}
|
|
3124
3200
|
|
|
3125
|
-
|
|
3126
|
-
|
|
3127
|
-
|
|
3201
|
+
for (int j = 0; j < QK_K / 16; ++j) {
|
|
3202
|
+
base_scales_ptr[j] = x[ib].scales[j];
|
|
3203
|
+
}
|
|
3128
3204
|
|
|
3129
|
-
|
|
3205
|
+
dm_ptr[ib] = x[ib].d;
|
|
3206
|
+
});
|
|
3207
|
+
if (!g_ggml_sycl_use_async_mem_op) {
|
|
3208
|
+
reorder_event.wait_and_throw();
|
|
3209
|
+
}
|
|
3210
|
+
sycl_ext_free(stream, tmp_buf);
|
|
3130
3211
|
}
|
|
3131
3212
|
|
|
3132
3213
|
static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
|
|
@@ -3233,6 +3314,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
|
|
3233
3314
|
bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
|
|
3234
3315
|
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
|
|
3235
3316
|
|
|
3317
|
+
|
|
3236
3318
|
// mmvq and mmq need the __dp4a instruction which is available for gen12+
|
|
3237
3319
|
// Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
|
|
3238
3320
|
use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
|
|
@@ -3240,7 +3322,6 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
|
|
3240
3322
|
use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
|
|
3241
3323
|
#endif // SYCL_USE_XMX
|
|
3242
3324
|
|
|
3243
|
-
|
|
3244
3325
|
// mmvq path is faster in the CUDA backend.
|
|
3245
3326
|
if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
|
|
3246
3327
|
// Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
|
|
@@ -3260,26 +3341,27 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
|
|
3260
3341
|
// The kernel from the if path is faster for that specific case, but does not support all mul mats.
|
|
3261
3342
|
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
|
|
3262
3343
|
}
|
|
3263
|
-
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) &&
|
|
3344
|
+
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1 && src1->ne[3] == 1) {
|
|
3264
3345
|
// KQV single-batch
|
|
3265
3346
|
ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
|
|
3266
|
-
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
|
|
3347
|
+
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) {
|
|
3267
3348
|
// KQ + KQV multi-batch
|
|
3268
3349
|
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
|
|
3269
3350
|
} else if (use_dequantize_mul_mat_vec) {
|
|
3270
|
-
constexpr bool convert_src1_to_q8_1 = false;
|
|
3271
3351
|
opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
|
|
3272
|
-
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec
|
|
3352
|
+
ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec);
|
|
3273
3353
|
} else if (use_mul_mat_vec_q) {
|
|
3274
|
-
constexpr bool convert_src1_to_q8_1 = true;
|
|
3275
3354
|
opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
|
|
3276
|
-
|
|
3355
|
+
ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
|
|
3356
|
+
if (extra && extra->optimized_feature.reorder) {
|
|
3357
|
+
ggml_sycl_op_mul_mat<quantize_and_reorder_q8_1_soa>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
|
|
3358
|
+
} else {
|
|
3359
|
+
ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
|
|
3360
|
+
}
|
|
3277
3361
|
} else if (use_mul_mat_q) {
|
|
3278
|
-
|
|
3279
|
-
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
|
|
3362
|
+
ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q);
|
|
3280
3363
|
} else {
|
|
3281
|
-
|
|
3282
|
-
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
|
|
3364
|
+
ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl);
|
|
3283
3365
|
}
|
|
3284
3366
|
}
|
|
3285
3367
|
|
|
@@ -3446,10 +3528,13 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
|
|
|
3446
3528
|
SYCL_CHECK(CHECK_TRY_ERROR(
|
|
3447
3529
|
stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
|
|
3448
3530
|
|
|
3531
|
+
const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
|
|
3532
|
+
assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
|
|
3533
|
+
|
|
3449
3534
|
{
|
|
3450
|
-
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10,
|
|
3535
|
+
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size));
|
|
3451
3536
|
sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
|
|
3452
|
-
|
|
3537
|
+
stream->submit([&](sycl::handler &cgh) {
|
|
3453
3538
|
sycl::local_accessor<int, 0> src1_row_acc(cgh);
|
|
3454
3539
|
|
|
3455
3540
|
char *__restrict src1_contiguous_get =
|
|
@@ -3461,8 +3546,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
|
|
|
3461
3546
|
size_t ids_nb_ct6 = ids->nb[1];
|
|
3462
3547
|
size_t ids_nb_ct7 = ids->nb[0];
|
|
3463
3548
|
|
|
3464
|
-
|
|
3465
|
-
|
|
3549
|
+
cgh.parallel_for(
|
|
3550
|
+
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
|
3551
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
3466
3552
|
k_copy_src1_to_contiguous(
|
|
3467
3553
|
src1_original, src1_contiguous_get,
|
|
3468
3554
|
dev_cur_src1_row_get,
|
|
@@ -3491,16 +3577,17 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
|
|
|
3491
3577
|
ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
|
|
3492
3578
|
|
|
3493
3579
|
{
|
|
3494
|
-
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0,
|
|
3580
|
+
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size));
|
|
3495
3581
|
sycl::range<3> grid_dims(1, 1, num_src1_rows);
|
|
3496
|
-
|
|
3582
|
+
stream->submit([&](sycl::handler &cgh) {
|
|
3497
3583
|
const char *__restrict dst_contiguous_get =
|
|
3498
3584
|
dst_contiguous.get();
|
|
3499
3585
|
const mmid_row_mapping *__restrict dev_row_mapping_get =
|
|
3500
3586
|
dev_row_mapping.get();
|
|
3501
3587
|
|
|
3502
|
-
|
|
3503
|
-
|
|
3588
|
+
cgh.parallel_for(
|
|
3589
|
+
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
|
3590
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
3504
3591
|
k_copy_dst_from_contiguous(dst_original,
|
|
3505
3592
|
dst_contiguous_get,
|
|
3506
3593
|
dev_row_mapping_get,
|
|
@@ -3549,6 +3636,12 @@ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
|
|
3549
3636
|
ggml_sycl_op_sum_rows(ctx, dst);
|
|
3550
3637
|
}
|
|
3551
3638
|
|
|
3639
|
+
static void ggml_sycl_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
3640
|
+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
|
3641
|
+
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
|
|
3642
|
+
ggml_sycl_op_mean(ctx, dst);
|
|
3643
|
+
}
|
|
3644
|
+
|
|
3552
3645
|
static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
3553
3646
|
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
|
3554
3647
|
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
|
|
@@ -3600,9 +3693,18 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
3600
3693
|
case GGML_OP_REPEAT:
|
|
3601
3694
|
ggml_sycl_repeat(ctx, dst);
|
|
3602
3695
|
break;
|
|
3696
|
+
case GGML_OP_REPEAT_BACK:
|
|
3697
|
+
ggml_sycl_repeat_back(ctx, dst);
|
|
3698
|
+
break;
|
|
3603
3699
|
case GGML_OP_GET_ROWS:
|
|
3604
3700
|
ggml_sycl_get_rows(ctx, dst);
|
|
3605
3701
|
break;
|
|
3702
|
+
case GGML_OP_SET:
|
|
3703
|
+
ggml_sycl_op_set(ctx, dst);
|
|
3704
|
+
break;
|
|
3705
|
+
case GGML_OP_SET_ROWS:
|
|
3706
|
+
ggml_sycl_op_set_rows(ctx, dst);
|
|
3707
|
+
break;
|
|
3606
3708
|
case GGML_OP_DUP:
|
|
3607
3709
|
ggml_sycl_dup(ctx, dst);
|
|
3608
3710
|
break;
|
|
@@ -3610,9 +3712,15 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
3610
3712
|
case GGML_OP_ADD1: // TODO: more efficient implementation
|
|
3611
3713
|
ggml_sycl_add(ctx, dst);
|
|
3612
3714
|
break;
|
|
3715
|
+
case GGML_OP_ADD_ID:
|
|
3716
|
+
ggml_sycl_add_id(ctx, dst);
|
|
3717
|
+
break;
|
|
3613
3718
|
case GGML_OP_SUB:
|
|
3614
3719
|
ggml_sycl_sub(ctx, dst);
|
|
3615
3720
|
break;
|
|
3721
|
+
case GGML_OP_COUNT_EQUAL:
|
|
3722
|
+
ggml_sycl_count_equal(ctx, dst);
|
|
3723
|
+
break;
|
|
3616
3724
|
case GGML_OP_ACC:
|
|
3617
3725
|
ggml_sycl_acc(ctx, dst);
|
|
3618
3726
|
break;
|
|
@@ -3672,6 +3780,18 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
3672
3780
|
case GGML_UNARY_OP_ELU:
|
|
3673
3781
|
ggml_sycl_elu(ctx, dst);
|
|
3674
3782
|
break;
|
|
3783
|
+
case GGML_UNARY_OP_FLOOR:
|
|
3784
|
+
ggml_sycl_floor(ctx, dst);
|
|
3785
|
+
break;
|
|
3786
|
+
case GGML_UNARY_OP_CEIL:
|
|
3787
|
+
ggml_sycl_ceil(ctx, dst);
|
|
3788
|
+
break;
|
|
3789
|
+
case GGML_UNARY_OP_ROUND:
|
|
3790
|
+
ggml_sycl_round(ctx, dst);
|
|
3791
|
+
break;
|
|
3792
|
+
case GGML_UNARY_OP_TRUNC:
|
|
3793
|
+
ggml_sycl_trunc(ctx, dst);
|
|
3794
|
+
break;
|
|
3675
3795
|
default:
|
|
3676
3796
|
return false;
|
|
3677
3797
|
}
|
|
@@ -3687,6 +3807,15 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
3687
3807
|
case GGML_GLU_OP_SWIGLU:
|
|
3688
3808
|
ggml_sycl_swiglu(ctx, dst);
|
|
3689
3809
|
break;
|
|
3810
|
+
case GGML_GLU_OP_SWIGLU_OAI:
|
|
3811
|
+
ggml_sycl_swiglu_oai(ctx, dst);
|
|
3812
|
+
break;
|
|
3813
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
3814
|
+
ggml_sycl_geglu_erf(ctx, dst);
|
|
3815
|
+
break;
|
|
3816
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
3817
|
+
ggml_sycl_geglu_quick(ctx, dst);
|
|
3818
|
+
break;
|
|
3690
3819
|
default:
|
|
3691
3820
|
return false;
|
|
3692
3821
|
}
|
|
@@ -3700,6 +3829,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
3700
3829
|
case GGML_OP_CONCAT:
|
|
3701
3830
|
ggml_sycl_op_concat(ctx, dst);
|
|
3702
3831
|
break;
|
|
3832
|
+
case GGML_OP_PAD_REFLECT_1D:
|
|
3833
|
+
ggml_sycl_op_pad_reflect_1d(ctx,dst);
|
|
3834
|
+
break;
|
|
3703
3835
|
case GGML_OP_UPSCALE:
|
|
3704
3836
|
ggml_sycl_upscale(ctx, dst);
|
|
3705
3837
|
break;
|
|
@@ -3709,6 +3841,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
3709
3841
|
case GGML_OP_LEAKY_RELU:
|
|
3710
3842
|
ggml_sycl_leaky_relu(ctx, dst);
|
|
3711
3843
|
break;
|
|
3844
|
+
case GGML_OP_RMS_NORM_BACK:
|
|
3845
|
+
ggml_sycl_rms_norm_back(ctx, dst);
|
|
3846
|
+
break;
|
|
3712
3847
|
case GGML_OP_RMS_NORM:
|
|
3713
3848
|
ggml_sycl_rms_norm(ctx, dst);
|
|
3714
3849
|
break;
|
|
@@ -3768,6 +3903,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
3768
3903
|
case GGML_OP_SOFT_MAX:
|
|
3769
3904
|
ggml_sycl_op_soft_max(ctx, dst);
|
|
3770
3905
|
break;
|
|
3906
|
+
case GGML_OP_SOFT_MAX_BACK:
|
|
3907
|
+
ggml_sycl_op_soft_max_back(ctx, dst);
|
|
3908
|
+
break;
|
|
3771
3909
|
case GGML_OP_ROPE:
|
|
3772
3910
|
ggml_sycl_rope(ctx, dst);
|
|
3773
3911
|
break;
|
|
@@ -3783,6 +3921,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
3783
3921
|
case GGML_OP_SUM_ROWS:
|
|
3784
3922
|
ggml_sycl_sum_rows(ctx, dst);
|
|
3785
3923
|
break;
|
|
3924
|
+
case GGML_OP_MEAN:
|
|
3925
|
+
ggml_sycl_mean(ctx, dst);
|
|
3926
|
+
break;
|
|
3786
3927
|
case GGML_OP_ARGSORT:
|
|
3787
3928
|
ggml_sycl_argsort(ctx, dst);
|
|
3788
3929
|
break;
|
|
@@ -3798,6 +3939,15 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
3798
3939
|
case GGML_OP_GATED_LINEAR_ATTN:
|
|
3799
3940
|
ggml_sycl_op_gated_linear_attn(ctx, dst);
|
|
3800
3941
|
break;
|
|
3942
|
+
case GGML_OP_SSM_CONV:
|
|
3943
|
+
ggml_sycl_ssm_conv(ctx, dst);
|
|
3944
|
+
break;
|
|
3945
|
+
case GGML_OP_ROLL:
|
|
3946
|
+
ggml_sycl_roll(ctx, dst);
|
|
3947
|
+
break;
|
|
3948
|
+
case GGML_OP_ARANGE:
|
|
3949
|
+
ggml_sycl_arange(ctx, dst);
|
|
3950
|
+
break;
|
|
3801
3951
|
default:
|
|
3802
3952
|
return false;
|
|
3803
3953
|
}
|
|
@@ -3805,6 +3955,7 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
3805
3955
|
return true;
|
|
3806
3956
|
} catch (sycl::exception & e) {
|
|
3807
3957
|
std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
|
|
3958
|
+
std::cerr << "Error OP "<<ggml_op_name(dst->op)<< std::endl;
|
|
3808
3959
|
std::exit(1);
|
|
3809
3960
|
}
|
|
3810
3961
|
|
|
@@ -3999,6 +4150,18 @@ static bool check_graph_compatibility(ggml_cgraph * cgraph) {
|
|
|
3999
4150
|
GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__,
|
|
4000
4151
|
ggml_op_name(node_op));
|
|
4001
4152
|
return false;
|
|
4153
|
+
case GGML_OP_MUL_MAT:
|
|
4154
|
+
// We cannot use graphs with ggml_sycl_mul_mat() when SYCL async memory allocation extensions are not available,
|
|
4155
|
+
// as SYCL malloc / free and host wait calls are not supported when recording to a graph which are all present
|
|
4156
|
+
// in reordering.
|
|
4157
|
+
if (!g_ggml_sycl_use_async_mem_op) {
|
|
4158
|
+
GGML_LOG_INFO(
|
|
4159
|
+
"%s: disabling SYCL graphs due to unsupported node type when using a compiler without the "
|
|
4160
|
+
"oneAPI async memory allocation extension "
|
|
4161
|
+
"%s\n",
|
|
4162
|
+
__func__, ggml_op_name(node_op));
|
|
4163
|
+
return false;
|
|
4164
|
+
}
|
|
4002
4165
|
}
|
|
4003
4166
|
}
|
|
4004
4167
|
return true;
|
|
@@ -4100,6 +4263,7 @@ static ggml_backend_i ggml_backend_sycl_interface = {
|
|
|
4100
4263
|
/* .graph_compute = */ ggml_backend_sycl_graph_compute,
|
|
4101
4264
|
/* .event_record = */ ggml_backend_sycl_event_record,
|
|
4102
4265
|
/* .event_wait = */ ggml_backend_sycl_event_wait,
|
|
4266
|
+
/* .graph_optimize = */ NULL,
|
|
4103
4267
|
};
|
|
4104
4268
|
|
|
4105
4269
|
static ggml_guid_t ggml_backend_sycl_guid() {
|
|
@@ -4122,6 +4286,7 @@ struct ggml_backend_sycl_device_context {
|
|
|
4122
4286
|
int device;
|
|
4123
4287
|
std::string name;
|
|
4124
4288
|
std::string description;
|
|
4289
|
+
int op_offload_min_batch_size;
|
|
4125
4290
|
};
|
|
4126
4291
|
|
|
4127
4292
|
static const char * ggml_backend_sycl_device_get_name(ggml_backend_dev_t dev) {
|
|
@@ -4192,6 +4357,9 @@ static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_
|
|
|
4192
4357
|
}
|
|
4193
4358
|
|
|
4194
4359
|
static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
|
4360
|
+
ggml_backend_sycl_device_context *sycl_ctx =
|
|
4361
|
+
(ggml_backend_sycl_device_context *)dev->context;
|
|
4362
|
+
int device = sycl_ctx->device;
|
|
4195
4363
|
switch (op->op) {
|
|
4196
4364
|
case GGML_OP_CONV_TRANSPOSE_1D:
|
|
4197
4365
|
{
|
|
@@ -4204,21 +4372,26 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
4204
4372
|
}
|
|
4205
4373
|
case GGML_OP_UNARY:
|
|
4206
4374
|
switch (ggml_get_unary_op(op)) {
|
|
4375
|
+
case GGML_UNARY_OP_SGN:
|
|
4376
|
+
case GGML_UNARY_OP_ABS:
|
|
4207
4377
|
case GGML_UNARY_OP_NEG:
|
|
4208
4378
|
case GGML_UNARY_OP_STEP:
|
|
4379
|
+
case GGML_UNARY_OP_RELU:
|
|
4380
|
+
case GGML_UNARY_OP_HARDSIGMOID:
|
|
4381
|
+
case GGML_UNARY_OP_TANH:
|
|
4209
4382
|
case GGML_UNARY_OP_GELU:
|
|
4210
4383
|
case GGML_UNARY_OP_SILU:
|
|
4211
|
-
case GGML_UNARY_OP_RELU:
|
|
4212
4384
|
case GGML_UNARY_OP_SIGMOID:
|
|
4213
|
-
case GGML_UNARY_OP_HARDSIGMOID:
|
|
4214
4385
|
case GGML_UNARY_OP_HARDSWISH:
|
|
4215
4386
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
4216
4387
|
case GGML_UNARY_OP_GELU_ERF:
|
|
4217
|
-
case GGML_UNARY_OP_TANH:
|
|
4218
4388
|
case GGML_UNARY_OP_EXP:
|
|
4219
|
-
case GGML_UNARY_OP_SGN:
|
|
4220
|
-
case GGML_UNARY_OP_ABS:
|
|
4221
4389
|
case GGML_UNARY_OP_ELU:
|
|
4390
|
+
return true;
|
|
4391
|
+
case GGML_UNARY_OP_FLOOR:
|
|
4392
|
+
case GGML_UNARY_OP_CEIL:
|
|
4393
|
+
case GGML_UNARY_OP_ROUND:
|
|
4394
|
+
case GGML_UNARY_OP_TRUNC:
|
|
4222
4395
|
#if defined (GGML_SYCL_F16)
|
|
4223
4396
|
return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
|
|
4224
4397
|
#else
|
|
@@ -4232,6 +4405,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
4232
4405
|
case GGML_GLU_OP_REGLU:
|
|
4233
4406
|
case GGML_GLU_OP_GEGLU:
|
|
4234
4407
|
case GGML_GLU_OP_SWIGLU:
|
|
4408
|
+
case GGML_GLU_OP_SWIGLU_OAI:
|
|
4409
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
4410
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
4235
4411
|
return ggml_is_contiguous_1(op->src[0]);
|
|
4236
4412
|
default:
|
|
4237
4413
|
return false;
|
|
@@ -4240,15 +4416,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
4240
4416
|
case GGML_OP_MUL_MAT:
|
|
4241
4417
|
case GGML_OP_MUL_MAT_ID:
|
|
4242
4418
|
{
|
|
4243
|
-
struct ggml_tensor * a;
|
|
4244
|
-
struct ggml_tensor * b;
|
|
4245
|
-
|
|
4246
|
-
a = op->src[0];
|
|
4247
|
-
b = op->src[1];
|
|
4248
|
-
} else {
|
|
4249
|
-
a = op->src[2];
|
|
4250
|
-
b = op->src[1];
|
|
4251
|
-
}
|
|
4419
|
+
struct ggml_tensor * a = op->src[0];
|
|
4420
|
+
struct ggml_tensor * b = op->src[1];
|
|
4421
|
+
|
|
4252
4422
|
if (a->ne[3] != b->ne[3]) {
|
|
4253
4423
|
return false;
|
|
4254
4424
|
}
|
|
@@ -4263,7 +4433,21 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
4263
4433
|
}
|
|
4264
4434
|
}
|
|
4265
4435
|
ggml_type src0_type = op->src[0]->type;
|
|
4266
|
-
if (src0_type == GGML_TYPE_BF16) {
|
|
4436
|
+
if (src0_type == GGML_TYPE_BF16 ) {
|
|
4437
|
+
// TODO: support GGML_TYPE_BF16
|
|
4438
|
+
// FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
|
|
4439
|
+
return false;
|
|
4440
|
+
}
|
|
4441
|
+
|
|
4442
|
+
// TODO: The configuration below needs more work to be supported with oneDNN
|
|
4443
|
+
if (ggml_is_permuted(a) && !ggml_is_contiguous(a) &&
|
|
4444
|
+
a->ne[2] > 1 && a->ne[3] > 1 && src0_type == GGML_TYPE_F16) {
|
|
4445
|
+
return false;
|
|
4446
|
+
}
|
|
4447
|
+
|
|
4448
|
+
// TODO: This specific configuration can fail with oneDNN and needs more debugging
|
|
4449
|
+
if (!ggml_is_permuted(a) && ggml_is_permuted(b) && b->ne[2] > 1 && b->ne[3] > 1 &&
|
|
4450
|
+
a->ne[0] > 128 && a->ne[2] == 1 && src0_type == GGML_TYPE_F16) {
|
|
4267
4451
|
return false;
|
|
4268
4452
|
}
|
|
4269
4453
|
return true;
|
|
@@ -4285,6 +4469,20 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
4285
4469
|
return false;
|
|
4286
4470
|
}
|
|
4287
4471
|
}
|
|
4472
|
+
case GGML_OP_SET:
|
|
4473
|
+
return (op->type == GGML_TYPE_F32) &&
|
|
4474
|
+
(op->src[0] && op->src[1]) &&
|
|
4475
|
+
(op->src[0]->type == GGML_TYPE_F32) &&
|
|
4476
|
+
(op->src[1]->type == GGML_TYPE_F32);
|
|
4477
|
+
|
|
4478
|
+
case GGML_OP_SET_ROWS:
|
|
4479
|
+
{
|
|
4480
|
+
return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
|
|
4481
|
+
op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q5_0 ||
|
|
4482
|
+
op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_IQ4_NL) &&
|
|
4483
|
+
(op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32));
|
|
4484
|
+
}
|
|
4485
|
+
break;
|
|
4288
4486
|
case GGML_OP_CPY:
|
|
4289
4487
|
{
|
|
4290
4488
|
ggml_type src0_type = op->src[0]->type;
|
|
@@ -4354,11 +4552,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
4354
4552
|
}
|
|
4355
4553
|
return false;
|
|
4356
4554
|
}
|
|
4357
|
-
case
|
|
4555
|
+
case GGML_OP_REPEAT_BACK:
|
|
4358
4556
|
{
|
|
4359
4557
|
ggml_type src0_type = op->src[0]->type;
|
|
4360
|
-
return src0_type
|
|
4558
|
+
return src0_type == GGML_TYPE_F32;
|
|
4361
4559
|
}
|
|
4560
|
+
case GGML_OP_CONCAT:
|
|
4362
4561
|
case GGML_OP_DUP:
|
|
4363
4562
|
case GGML_OP_ARGMAX:
|
|
4364
4563
|
case GGML_OP_NONE:
|
|
@@ -4366,14 +4565,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
4366
4565
|
case GGML_OP_VIEW:
|
|
4367
4566
|
case GGML_OP_PERMUTE:
|
|
4368
4567
|
case GGML_OP_TRANSPOSE:
|
|
4369
|
-
return true;
|
|
4370
4568
|
case GGML_OP_ADD:
|
|
4371
4569
|
case GGML_OP_ADD1:
|
|
4570
|
+
case GGML_OP_ADD_ID:
|
|
4372
4571
|
case GGML_OP_SUB:
|
|
4572
|
+
case GGML_OP_COUNT_EQUAL:
|
|
4373
4573
|
case GGML_OP_MUL:
|
|
4374
4574
|
case GGML_OP_DIV:
|
|
4375
4575
|
case GGML_OP_REPEAT:
|
|
4376
4576
|
return true;
|
|
4577
|
+
case GGML_OP_PAD_REFLECT_1D:
|
|
4578
|
+
return ggml_is_contiguous(op->src[0]) && op-> type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
|
|
4377
4579
|
case GGML_OP_SQR:
|
|
4378
4580
|
case GGML_OP_SQRT:
|
|
4379
4581
|
case GGML_OP_SIN:
|
|
@@ -4386,35 +4588,62 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
4386
4588
|
return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
|
|
4387
4589
|
#endif
|
|
4388
4590
|
case GGML_OP_NORM:
|
|
4389
|
-
case GGML_OP_RMS_NORM:
|
|
4390
4591
|
return true;
|
|
4391
4592
|
case GGML_OP_L2_NORM:
|
|
4392
4593
|
case GGML_OP_GROUP_NORM:
|
|
4393
4594
|
return ggml_is_contiguous(op->src[0]);
|
|
4595
|
+
case GGML_OP_RMS_NORM:
|
|
4596
|
+
return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
|
|
4597
|
+
case GGML_OP_RMS_NORM_BACK:
|
|
4598
|
+
return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
|
|
4394
4599
|
case GGML_OP_SCALE:
|
|
4395
4600
|
return true;
|
|
4396
4601
|
case GGML_OP_CONT:
|
|
4397
4602
|
return op->src[0]->type != GGML_TYPE_BF16;
|
|
4398
4603
|
case GGML_OP_DIAG_MASK_INF:
|
|
4604
|
+
return true;
|
|
4399
4605
|
case GGML_OP_SOFT_MAX:
|
|
4400
4606
|
return true;
|
|
4607
|
+
case GGML_OP_SOFT_MAX_BACK: {
|
|
4608
|
+
float max_bias = 0.0f;
|
|
4609
|
+
memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
|
|
4610
|
+
return max_bias == 0.0f;
|
|
4611
|
+
}
|
|
4401
4612
|
case GGML_OP_ROPE:
|
|
4402
4613
|
case GGML_OP_IM2COL:
|
|
4403
4614
|
return true;
|
|
4404
4615
|
case GGML_OP_UPSCALE:
|
|
4405
|
-
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
|
4406
|
-
case GGML_OP_POOL_2D:
|
|
4616
|
+
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
|
|
4407
4617
|
case GGML_OP_SUM:
|
|
4408
4618
|
case GGML_OP_SUM_ROWS:
|
|
4619
|
+
case GGML_OP_MEAN:
|
|
4620
|
+
return ggml_is_contiguous(op->src[0]);
|
|
4409
4621
|
case GGML_OP_ARGSORT:
|
|
4622
|
+
return op->src[0]->ne[0] * sizeof(int) <=
|
|
4623
|
+
ggml_sycl_info().devices[device].smpbo;
|
|
4624
|
+
case GGML_OP_POOL_2D:
|
|
4410
4625
|
case GGML_OP_ACC:
|
|
4626
|
+
return true;
|
|
4411
4627
|
case GGML_OP_PAD:
|
|
4628
|
+
// TODO: add circular padding support for syscl, see https://github.com/ggml-org/llama.cpp/pull/16985
|
|
4629
|
+
if (ggml_get_op_params_i32(op, 8) != 0) {
|
|
4630
|
+
return false;
|
|
4631
|
+
}
|
|
4632
|
+
return ggml_is_contiguous(op->src[0]);
|
|
4412
4633
|
case GGML_OP_LEAKY_RELU:
|
|
4413
4634
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
4414
4635
|
case GGML_OP_RWKV_WKV6:
|
|
4415
4636
|
case GGML_OP_RWKV_WKV7:
|
|
4416
4637
|
case GGML_OP_GATED_LINEAR_ATTN:
|
|
4417
4638
|
return true;
|
|
4639
|
+
case GGML_OP_SSM_CONV:
|
|
4640
|
+
return op->type == GGML_TYPE_F32 &&
|
|
4641
|
+
op->src[0]->type == GGML_TYPE_F32 &&
|
|
4642
|
+
op->src[1]->type == GGML_TYPE_F32;
|
|
4643
|
+
case GGML_OP_ROLL:
|
|
4644
|
+
return op->type == GGML_TYPE_F32;
|
|
4645
|
+
case GGML_OP_ARANGE:
|
|
4646
|
+
return op->type == GGML_TYPE_F32;
|
|
4418
4647
|
default:
|
|
4419
4648
|
return false;
|
|
4420
4649
|
}
|
|
@@ -4446,9 +4675,8 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
|
|
|
4446
4675
|
}
|
|
4447
4676
|
|
|
4448
4677
|
static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
|
4449
|
-
|
|
4450
|
-
return get_op_batch_size(op) >=
|
|
4451
|
-
GGML_UNUSED(dev);
|
|
4678
|
+
ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context;
|
|
4679
|
+
return get_op_batch_size(op) >= sycl_ctx->op_offload_min_batch_size;
|
|
4452
4680
|
}
|
|
4453
4681
|
|
|
4454
4682
|
static ggml_backend_event_t
|
|
@@ -4571,6 +4799,7 @@ ggml_backend_reg_t ggml_backend_sycl_reg() {
|
|
|
4571
4799
|
std::lock_guard<std::mutex> lock(mutex);
|
|
4572
4800
|
if (!initialized) {
|
|
4573
4801
|
ggml_backend_sycl_reg_context * ctx = new ggml_backend_sycl_reg_context;
|
|
4802
|
+
const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
|
|
4574
4803
|
|
|
4575
4804
|
for (int i = 0; i < ggml_sycl_info().device_count; i++) {
|
|
4576
4805
|
ggml_backend_sycl_device_context * dev_ctx = new ggml_backend_sycl_device_context;
|
|
@@ -4584,6 +4813,7 @@ ggml_backend_reg_t ggml_backend_sycl_reg() {
|
|
|
4584
4813
|
prop, dpct::dev_mgr::instance().get_device(i))));
|
|
4585
4814
|
|
|
4586
4815
|
dev_ctx->description = prop.get_name();
|
|
4816
|
+
dev_ctx->op_offload_min_batch_size = min_batch_size;
|
|
4587
4817
|
|
|
4588
4818
|
ggml_backend_dev_t dev = new ggml_backend_device {
|
|
4589
4819
|
/* .iface = */ ggml_backend_sycl_device_interface,
|
|
@@ -4619,10 +4849,10 @@ ggml_backend_t ggml_backend_sycl_init(int device) {
|
|
|
4619
4849
|
};
|
|
4620
4850
|
|
|
4621
4851
|
ggml_backend_t sycl_backend = new ggml_backend {
|
|
4622
|
-
/* .guid
|
|
4623
|
-
/* .
|
|
4624
|
-
/* .device
|
|
4625
|
-
/* .context
|
|
4852
|
+
/* .guid = */ ggml_backend_sycl_guid(),
|
|
4853
|
+
/* .iface = */ ggml_backend_sycl_interface,
|
|
4854
|
+
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device),
|
|
4855
|
+
/* .context = */ ctx
|
|
4626
4856
|
};
|
|
4627
4857
|
|
|
4628
4858
|
return sycl_backend;
|