whispercpp 1.3.4 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +158 -44
- data/ext/extconf.rb +3 -2
- data/ext/ruby_whisper.c +34 -6
- data/ext/ruby_whisper.h +67 -0
- data/ext/ruby_whisper_context.c +236 -144
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +12 -13
- data/ext/ruby_whisper_params.c +47 -24
- data/ext/ruby_whisper_segment.c +84 -20
- data/ext/ruby_whisper_token.c +371 -0
- data/ext/ruby_whisper_transcribe.cpp +5 -2
- data/ext/ruby_whisper_vad_context.c +122 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +138 -0
- data/ext/ruby_whisper_vad_segments.c +105 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +129 -112
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +28 -15
- data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
- data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
- data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
- data/ext/sources/examples/talk-llama/llama-context.h +70 -23
- data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
- data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
- data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
- data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
- data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
- data/ext/sources/examples/talk-llama/llama-model.h +112 -18
- data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
- data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
- data/ext/sources/examples/talk-llama/llama.cpp +802 -21
- data/ext/sources/examples/talk-llama/llama.h +210 -39
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +704 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +90 -56
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +5 -2
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +6 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +14 -12
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +246 -21
- data/ext/sources/ggml/src/CMakeLists.txt +85 -11
- data/ext/sources/ggml/src/ggml-alloc.c +128 -50
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
- data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
- data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
- data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
- data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
- data/ext/sources/ggml/src/ggml-impl.h +129 -6
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
- data/ext/sources/ggml/src/ggml.c +590 -64
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +106 -62
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +162 -4
- data/test/test_context_params.rb +82 -0
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +81 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +100 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +27 -0
- data/whispercpp.gemspec +1 -1
- metadata +502 -37
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -30,22 +30,32 @@
|
|
|
30
30
|
#include <regex>
|
|
31
31
|
|
|
32
32
|
#include <sycl/sycl.hpp>
|
|
33
|
+
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
|
|
34
|
+
# include <sycl/ext/oneapi/experimental/async_alloc/async_alloc.hpp>
|
|
35
|
+
#endif
|
|
33
36
|
#include <sycl/half_type.hpp>
|
|
34
37
|
|
|
38
|
+
#include "ggml.h"
|
|
35
39
|
#include "ggml-sycl.h"
|
|
36
40
|
#include "ggml-impl.h"
|
|
37
41
|
#include "ggml-backend-impl.h"
|
|
38
42
|
|
|
43
|
+
#include "ggml-sycl/add-id.hpp"
|
|
39
44
|
#include "ggml-sycl/backend.hpp"
|
|
40
45
|
#include "ggml-sycl/common.hpp"
|
|
41
46
|
#include "ggml-sycl/element_wise.hpp"
|
|
42
|
-
#include "ggml-sycl/presets.hpp"
|
|
47
|
+
#include "ggml-sycl/gated_delta_net.hpp"
|
|
43
48
|
#include "ggml-sycl/gemm.hpp"
|
|
44
|
-
#include "ggml-sycl/set_rows.hpp"
|
|
45
|
-
#include "ggml-sycl/sycl_hw.hpp"
|
|
46
49
|
#include "ggml-sycl/getrows.hpp"
|
|
50
|
+
#include "ggml-sycl/norm.hpp"
|
|
51
|
+
#include "ggml-sycl/presets.hpp"
|
|
47
52
|
#include "ggml-sycl/quantize.hpp"
|
|
48
|
-
#include "ggml.h"
|
|
53
|
+
#include "ggml-sycl/repeat_back.hpp"
|
|
54
|
+
#include "ggml-sycl/set_rows.hpp"
|
|
55
|
+
#include "ggml-sycl/set.hpp"
|
|
56
|
+
#include "ggml-sycl/ssm_conv.hpp"
|
|
57
|
+
#include "ggml-sycl/sycl_hw.hpp"
|
|
58
|
+
|
|
49
59
|
|
|
50
60
|
static bool g_sycl_loaded = false;
|
|
51
61
|
int g_ggml_sycl_debug = 0;
|
|
@@ -53,6 +63,9 @@ int g_ggml_sycl_disable_optimize = 0;
|
|
|
53
63
|
int g_ggml_sycl_disable_graph = 0;
|
|
54
64
|
int g_ggml_sycl_disable_dnn = 0;
|
|
55
65
|
int g_ggml_sycl_prioritize_dmmv = 0;
|
|
66
|
+
int g_ggml_sycl_use_async_mem_op = 0;
|
|
67
|
+
int g_ggml_sycl_enable_flash_attention = 1;
|
|
68
|
+
|
|
56
69
|
|
|
57
70
|
static ggml_sycl_device_info ggml_sycl_init() {
|
|
58
71
|
ggml_sycl_device_info info = {};
|
|
@@ -85,8 +98,14 @@ static ggml_sycl_device_info ggml_sycl_init() {
|
|
|
85
98
|
|
|
86
99
|
info.devices[i].cc =
|
|
87
100
|
100 * prop.get_major_version() + 10 * prop.get_minor_version();
|
|
101
|
+
info.devices[i].nsm = prop.get_max_compute_units() / 16; //16: Number of Xe Cores
|
|
88
102
|
info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
|
|
103
|
+
info.devices[i].smpbo = prop.get_local_mem_size();
|
|
104
|
+
info.devices[i].warp_size = WARP_SIZE;
|
|
105
|
+
|
|
89
106
|
info.max_work_group_sizes[i] = prop.get_max_work_group_size();
|
|
107
|
+
info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units();
|
|
108
|
+
|
|
90
109
|
}
|
|
91
110
|
|
|
92
111
|
for (int id = 0; id < info.device_count; ++id) {
|
|
@@ -199,7 +218,37 @@ static void ggml_check_sycl() try {
|
|
|
199
218
|
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
|
|
200
219
|
g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
|
|
201
220
|
g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
|
|
221
|
+
|
|
222
|
+
#ifdef SYCL_FLASH_ATTN
|
|
223
|
+
g_ggml_sycl_enable_flash_attention = get_sycl_env("GGML_SYCL_ENABLE_FLASH_ATTN", 1);
|
|
224
|
+
#else
|
|
225
|
+
g_ggml_sycl_enable_flash_attention = 0;
|
|
226
|
+
#endif
|
|
227
|
+
|
|
202
228
|
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
|
|
229
|
+
|
|
230
|
+
GGML_LOG_INFO("Build with Macros:\n");
|
|
231
|
+
#if defined(GGML_SYCL_FORCE_MMQ)
|
|
232
|
+
GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
|
|
233
|
+
#else
|
|
234
|
+
GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n");
|
|
235
|
+
#endif
|
|
236
|
+
#if defined(GGML_SYCL_F16)
|
|
237
|
+
GGML_LOG_INFO(" GGML_SYCL_F16: yes\n");
|
|
238
|
+
#else
|
|
239
|
+
GGML_LOG_INFO(" GGML_SYCL_F16: no\n");
|
|
240
|
+
#endif
|
|
241
|
+
#if defined(GGML_SYCL_GRAPH)
|
|
242
|
+
GGML_LOG_INFO(" GGML_SYCL_GRAPH: yes\n");
|
|
243
|
+
#else
|
|
244
|
+
GGML_LOG_INFO(" GGML_SYCL_GRAPH: no\n");
|
|
245
|
+
#endif
|
|
246
|
+
#if defined(GGML_SYCL_DNNL)
|
|
247
|
+
GGML_LOG_INFO(" GGML_SYCL_DNNL: yes\n");
|
|
248
|
+
#else
|
|
249
|
+
GGML_LOG_INFO(" GGML_SYCL_DNNL: no\n");
|
|
250
|
+
#endif
|
|
251
|
+
|
|
203
252
|
GGML_LOG_INFO("Running with Environment Variables:\n");
|
|
204
253
|
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
|
|
205
254
|
GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
|
|
@@ -214,16 +263,12 @@ static void ggml_check_sycl() try {
|
|
|
214
263
|
GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
|
|
215
264
|
#endif
|
|
216
265
|
GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
|
|
217
|
-
|
|
218
|
-
#
|
|
219
|
-
GGML_LOG_INFO("
|
|
220
|
-
#else
|
|
221
|
-
GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n");
|
|
222
|
-
#endif
|
|
223
|
-
#if defined(GGML_SYCL_F16)
|
|
224
|
-
GGML_LOG_INFO(" GGML_SYCL_F16: yes\n");
|
|
266
|
+
|
|
267
|
+
#ifdef SYCL_FLASH_ATTN
|
|
268
|
+
GGML_LOG_INFO(" GGML_SYCL_ENABLE_FLASH_ATTN: %d\n", g_ggml_sycl_enable_flash_attention);
|
|
225
269
|
#else
|
|
226
|
-
GGML_LOG_INFO("
|
|
270
|
+
GGML_LOG_INFO(" GGML_SYCL_ENABLE_FLASH_ATTN: %d disabled by compile flag\n",
|
|
271
|
+
g_ggml_sycl_enable_flash_attention);
|
|
227
272
|
#endif
|
|
228
273
|
|
|
229
274
|
/* NOT REMOVE, keep it for next optimize for XMX.
|
|
@@ -233,7 +278,20 @@ static void ggml_check_sycl() try {
|
|
|
233
278
|
fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__);
|
|
234
279
|
#endif
|
|
235
280
|
*/
|
|
236
|
-
|
|
281
|
+
// Currently, we only use async malloc / free when graphs are enabled as it is required for the calls to be
|
|
282
|
+
// properly recorded. As this SYCL extension matures it may be beneficial to enable as the default path and in
|
|
283
|
+
// other places.
|
|
284
|
+
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
|
|
285
|
+
g_ggml_sycl_use_async_mem_op = !g_ggml_sycl_disable_graph;
|
|
286
|
+
if (g_ggml_sycl_use_async_mem_op) {
|
|
287
|
+
for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); ++i) {
|
|
288
|
+
if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) {
|
|
289
|
+
g_ggml_sycl_use_async_mem_op = 0;
|
|
290
|
+
break;
|
|
291
|
+
}
|
|
292
|
+
}
|
|
293
|
+
}
|
|
294
|
+
#endif
|
|
237
295
|
if (CHECK_TRY_ERROR(g_all_sycl_device_count =
|
|
238
296
|
dpct::dev_mgr::instance().device_count()) != 0) {
|
|
239
297
|
initialized = true;
|
|
@@ -1132,13 +1190,28 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_
|
|
|
1132
1190
|
GGML_UNUSED(buft);
|
|
1133
1191
|
}
|
|
1134
1192
|
|
|
1193
|
+
inline void * aligned_malloc_host(size_t alignment, size_t size) {
|
|
1194
|
+
#ifdef _WIN32
|
|
1195
|
+
return _aligned_malloc(size, alignment);
|
|
1196
|
+
#else
|
|
1197
|
+
return aligned_alloc(alignment, size);
|
|
1198
|
+
#endif
|
|
1199
|
+
}
|
|
1200
|
+
|
|
1201
|
+
inline void free_aligned_mem_host(void * memblock) {
|
|
1202
|
+
#ifdef _WIN32
|
|
1203
|
+
_aligned_free(memblock);
|
|
1204
|
+
#else
|
|
1205
|
+
free(memblock);
|
|
1206
|
+
#endif
|
|
1207
|
+
}
|
|
1208
|
+
|
|
1135
1209
|
static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
|
1136
|
-
|
|
1210
|
+
free_aligned_mem_host((void *)buffer->context);
|
|
1137
1211
|
}
|
|
1138
1212
|
|
|
1139
1213
|
static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
|
1140
|
-
void * ptr =
|
|
1141
|
-
|
|
1214
|
+
void * ptr = aligned_malloc_host(TENSOR_ALIGNMENT, size);
|
|
1142
1215
|
if (ptr == nullptr) {
|
|
1143
1216
|
// fallback to cpu buffer
|
|
1144
1217
|
return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
|
|
@@ -1511,60 +1584,70 @@ static inline void ggml_sycl_swap(T & a, T & b) {
|
|
|
1511
1584
|
template <ggml_sort_order order>
|
|
1512
1585
|
__dpct_inline__ static void
|
|
1513
1586
|
k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
|
|
1514
|
-
const sycl::nd_item<3> &item_ct1,
|
|
1587
|
+
const int tasks_per_thread, const sycl::nd_item<3> &item_ct1,
|
|
1588
|
+
uint8_t *dpct_local) {
|
|
1515
1589
|
// bitonic sort
|
|
1516
|
-
int
|
|
1590
|
+
int col_index = item_ct1.get_local_id(2);
|
|
1517
1591
|
int row = item_ct1.get_group(1);
|
|
1518
1592
|
|
|
1519
|
-
|
|
1520
|
-
|
|
1593
|
+
for (int i = 0; i < tasks_per_thread; i++) {
|
|
1594
|
+
int col = col_index * tasks_per_thread + i;
|
|
1595
|
+
if (col >= ncols_pad) {
|
|
1596
|
+
return;
|
|
1597
|
+
}
|
|
1521
1598
|
}
|
|
1522
1599
|
|
|
1523
1600
|
const float * x_row = x + row * ncols;
|
|
1524
1601
|
auto dst_row = (int *)dpct_local;
|
|
1525
1602
|
|
|
1526
1603
|
// initialize indices
|
|
1527
|
-
|
|
1604
|
+
for (int i=0;i<tasks_per_thread;i++){
|
|
1605
|
+
int col = col_index*tasks_per_thread+i;
|
|
1606
|
+
dst_row[col] = col;
|
|
1607
|
+
}
|
|
1528
1608
|
|
|
1529
1609
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
1530
1610
|
|
|
1531
1611
|
for (int k = 2; k <= ncols_pad; k *= 2) {
|
|
1532
1612
|
for (int j = k / 2; j > 0; j /= 2) {
|
|
1533
|
-
int
|
|
1534
|
-
|
|
1535
|
-
|
|
1536
|
-
|
|
1537
|
-
|
|
1538
|
-
|
|
1539
|
-
|
|
1540
|
-
|
|
1541
|
-
|
|
1542
|
-
|
|
1543
|
-
|
|
1544
|
-
|
|
1545
|
-
|
|
1546
|
-
|
|
1547
|
-
|
|
1548
|
-
|
|
1549
|
-
|
|
1613
|
+
for (int i = 0; i < tasks_per_thread; i++) {
|
|
1614
|
+
int col = col_index * tasks_per_thread + i;
|
|
1615
|
+
int ixj = col ^ j;
|
|
1616
|
+
if (ixj > col) {
|
|
1617
|
+
if ((col & k) == 0) {
|
|
1618
|
+
if (dst_row[col] >= ncols ||
|
|
1619
|
+
(dst_row[ixj] < ncols &&
|
|
1620
|
+
(order == GGML_SORT_ORDER_ASC
|
|
1621
|
+
? x_row[dst_row[col]] > x_row[dst_row[ixj]]
|
|
1622
|
+
: x_row[dst_row[col]] <
|
|
1623
|
+
x_row[dst_row[ixj]]))) {
|
|
1624
|
+
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
|
|
1625
|
+
}
|
|
1626
|
+
} else {
|
|
1627
|
+
if (dst_row[ixj] >= ncols ||
|
|
1628
|
+
(dst_row[col] < ncols &&
|
|
1629
|
+
(order == GGML_SORT_ORDER_ASC
|
|
1630
|
+
? x_row[dst_row[col]] < x_row[dst_row[ixj]]
|
|
1631
|
+
: x_row[dst_row[col]] >
|
|
1632
|
+
x_row[dst_row[ixj]]))) {
|
|
1633
|
+
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
|
|
1634
|
+
}
|
|
1550
1635
|
}
|
|
1551
1636
|
}
|
|
1637
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
1552
1638
|
}
|
|
1553
|
-
/*
|
|
1554
|
-
DPCT1118:1: SYCL group functions and algorithms must be encountered
|
|
1555
|
-
in converged control flow. You may need to adjust the code.
|
|
1556
|
-
*/
|
|
1557
|
-
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
1558
1639
|
}
|
|
1559
1640
|
}
|
|
1560
1641
|
|
|
1561
1642
|
// copy the result to dst without the padding
|
|
1562
|
-
|
|
1563
|
-
|
|
1643
|
+
for (int i = 0; i < tasks_per_thread; i++) {
|
|
1644
|
+
int col = col_index * tasks_per_thread + i;
|
|
1645
|
+
if (col < ncols) {
|
|
1646
|
+
dst[row * ncols + col] = dst_row[col];
|
|
1647
|
+
}
|
|
1564
1648
|
}
|
|
1565
1649
|
}
|
|
1566
1650
|
|
|
1567
|
-
|
|
1568
1651
|
static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
|
|
1569
1652
|
const sycl::nd_item<3> &item_ct1) {
|
|
1570
1653
|
const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
|
@@ -1737,13 +1820,23 @@ static int next_power_of_2(int x) {
|
|
|
1737
1820
|
|
|
1738
1821
|
static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
|
|
1739
1822
|
const int nrows, ggml_sort_order order,
|
|
1740
|
-
queue_ptr stream) {
|
|
1823
|
+
queue_ptr stream, int device) {
|
|
1741
1824
|
// bitonic sort requires ncols to be power of 2
|
|
1742
1825
|
const int ncols_pad = next_power_of_2(ncols);
|
|
1743
1826
|
|
|
1744
|
-
|
|
1827
|
+
int nth = 1;
|
|
1828
|
+
int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
|
|
1829
|
+
while (nth < ncols_pad && nth < max_block_size)
|
|
1830
|
+
nth *= 2;
|
|
1831
|
+
if (nth > max_block_size)
|
|
1832
|
+
nth = max_block_size;
|
|
1833
|
+
|
|
1834
|
+
const int tasks_per_thread = ncols_pad / nth;
|
|
1835
|
+
|
|
1836
|
+
const sycl::range<3> block_dims(1, 1, nth);
|
|
1745
1837
|
const sycl::range<3> block_nums(1, nrows, 1);
|
|
1746
1838
|
const size_t shared_mem = ncols_pad * sizeof(int);
|
|
1839
|
+
GGML_ASSERT(shared_mem<=ggml_sycl_info().devices[device].smpbo);
|
|
1747
1840
|
|
|
1748
1841
|
if (order == GGML_SORT_ORDER_ASC) {
|
|
1749
1842
|
stream->submit([&](sycl::handler &cgh) {
|
|
@@ -1754,8 +1847,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
|
|
|
1754
1847
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
1755
1848
|
[=](sycl::nd_item<3> item_ct1) {
|
|
1756
1849
|
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
|
|
1757
|
-
x, dst, ncols, ncols_pad, item_ct1,
|
|
1758
|
-
dpct_local_acc_ct1
|
|
1850
|
+
x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
|
|
1851
|
+
dpct_local_acc_ct1
|
|
1852
|
+
.get_multi_ptr<sycl::access::decorated::no>()
|
|
1759
1853
|
.get());
|
|
1760
1854
|
});
|
|
1761
1855
|
});
|
|
@@ -1768,8 +1862,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
|
|
|
1768
1862
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
1769
1863
|
[=](sycl::nd_item<3> item_ct1) {
|
|
1770
1864
|
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
|
|
1771
|
-
x, dst, ncols, ncols_pad, item_ct1,
|
|
1772
|
-
dpct_local_acc_ct1
|
|
1865
|
+
x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
|
|
1866
|
+
dpct_local_acc_ct1
|
|
1867
|
+
.get_multi_ptr<sycl::access::decorated::no>()
|
|
1773
1868
|
.get());
|
|
1774
1869
|
});
|
|
1775
1870
|
});
|
|
@@ -1778,6 +1873,110 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
|
|
|
1778
1873
|
}
|
|
1779
1874
|
}
|
|
1780
1875
|
|
|
1876
|
+
static void top_k_f32_sycl(
|
|
1877
|
+
const float * src,
|
|
1878
|
+
int32_t * dst_indices,
|
|
1879
|
+
const int64_t ncols,
|
|
1880
|
+
const int64_t nrows,
|
|
1881
|
+
const int k,
|
|
1882
|
+
dpct::queue_ptr main_stream
|
|
1883
|
+
) {
|
|
1884
|
+
const int block_size = 128;
|
|
1885
|
+
|
|
1886
|
+
const sycl::range<1> block_dims(block_size);
|
|
1887
|
+
const sycl::range<1> grid_dims(nrows);
|
|
1888
|
+
|
|
1889
|
+
main_stream->submit([&](sycl::handler &cgh) {
|
|
1890
|
+
sycl::local_accessor<float, 1> shared_vals(sycl::range<1>(block_size * k), cgh);
|
|
1891
|
+
sycl::local_accessor<int, 1> shared_idx(sycl::range<1>(block_size * k), cgh);
|
|
1892
|
+
|
|
1893
|
+
cgh.parallel_for(
|
|
1894
|
+
sycl::nd_range<1>(grid_dims * block_dims, block_dims),
|
|
1895
|
+
[=](sycl::nd_item<1> item_ct1) {
|
|
1896
|
+
const int row = item_ct1.get_group(0);
|
|
1897
|
+
const int tid = item_ct1.get_local_id(0);
|
|
1898
|
+
|
|
1899
|
+
if (row >= nrows) return;
|
|
1900
|
+
|
|
1901
|
+
const float * src_row = src + row * ncols;
|
|
1902
|
+
int32_t * dst_idx_row = dst_indices + row * k;
|
|
1903
|
+
|
|
1904
|
+
float local_vals[32];
|
|
1905
|
+
int local_idx[32];
|
|
1906
|
+
|
|
1907
|
+
for (int i = 0; i < k; i++) {
|
|
1908
|
+
local_vals[i] = -FLT_MAX;
|
|
1909
|
+
local_idx[i] = -1;
|
|
1910
|
+
}
|
|
1911
|
+
|
|
1912
|
+
for (int col = tid; col < ncols; col += block_size) {
|
|
1913
|
+
float val = src_row[col];
|
|
1914
|
+
|
|
1915
|
+
if (val > local_vals[k-1]) {
|
|
1916
|
+
int pos = k - 1;
|
|
1917
|
+
while (pos > 0 && val > local_vals[pos - 1]) {
|
|
1918
|
+
pos--;
|
|
1919
|
+
}
|
|
1920
|
+
|
|
1921
|
+
for (int i = k - 1; i > pos; i--) {
|
|
1922
|
+
local_vals[i] = local_vals[i - 1];
|
|
1923
|
+
local_idx[i] = local_idx[i - 1];
|
|
1924
|
+
}
|
|
1925
|
+
local_vals[pos] = val;
|
|
1926
|
+
local_idx[pos] = col;
|
|
1927
|
+
}
|
|
1928
|
+
}
|
|
1929
|
+
|
|
1930
|
+
for (int i = 0; i < k; i++) {
|
|
1931
|
+
shared_vals[tid * k + i] = local_vals[i];
|
|
1932
|
+
shared_idx[tid * k + i] = local_idx[i];
|
|
1933
|
+
}
|
|
1934
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
1935
|
+
|
|
1936
|
+
if (tid == 0) {
|
|
1937
|
+
float final_vals[32];
|
|
1938
|
+
int final_idx[32];
|
|
1939
|
+
|
|
1940
|
+
for (int i = 0; i < k; i++) {
|
|
1941
|
+
final_vals[i] = -FLT_MAX;
|
|
1942
|
+
final_idx[i] = -1;
|
|
1943
|
+
}
|
|
1944
|
+
|
|
1945
|
+
for (int t = 0; t < block_size; t++) {
|
|
1946
|
+
for (int i = 0; i < k; i++) {
|
|
1947
|
+
float val = shared_vals[t * k + i];
|
|
1948
|
+
int idx = shared_idx[t * k + i];
|
|
1949
|
+
|
|
1950
|
+
if (val > final_vals[k-1]) {
|
|
1951
|
+
int pos = k - 1;
|
|
1952
|
+
while (pos > 0 && val > final_vals[pos - 1]) {
|
|
1953
|
+
pos--;
|
|
1954
|
+
}
|
|
1955
|
+
|
|
1956
|
+
for (int j = k - 1; j > pos; j--) {
|
|
1957
|
+
final_vals[j] = final_vals[j - 1];
|
|
1958
|
+
final_idx[j] = final_idx[j - 1];
|
|
1959
|
+
}
|
|
1960
|
+
final_vals[pos] = val;
|
|
1961
|
+
final_idx[pos] = idx;
|
|
1962
|
+
}
|
|
1963
|
+
}
|
|
1964
|
+
}
|
|
1965
|
+
|
|
1966
|
+
for (int i = 0; i < k; i++) {
|
|
1967
|
+
dst_idx_row[i] = final_idx[i];
|
|
1968
|
+
}
|
|
1969
|
+
|
|
1970
|
+
if (k > 1) {
|
|
1971
|
+
int32_t temp = dst_idx_row[0];
|
|
1972
|
+
dst_idx_row[0] = dst_idx_row[1];
|
|
1973
|
+
dst_idx_row[1] = temp;
|
|
1974
|
+
}
|
|
1975
|
+
}
|
|
1976
|
+
});
|
|
1977
|
+
});
|
|
1978
|
+
}
|
|
1979
|
+
|
|
1781
1980
|
static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
|
|
1782
1981
|
const int nrows, queue_ptr stream) {
|
|
1783
1982
|
const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE);
|
|
@@ -2001,8 +2200,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
|
|
2001
2200
|
const sycl::half alpha_f16 = 1.0f;
|
|
2002
2201
|
const sycl::half beta_f16 = 0.0f;
|
|
2003
2202
|
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
|
|
2004
|
-
*stream, oneapi::
|
|
2005
|
-
oneapi::
|
|
2203
|
+
*stream, oneapi::mkl::transpose::trans,
|
|
2204
|
+
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
|
|
2006
2205
|
&alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
|
|
2007
2206
|
src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
|
|
2008
2207
|
dst_f16.get(), dpct::library_data_t::real_half, ldc,
|
|
@@ -2045,8 +2244,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
|
|
2045
2244
|
{
|
|
2046
2245
|
const float alpha = 1.0f;
|
|
2047
2246
|
const float beta = 0.0f;
|
|
2048
|
-
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::
|
|
2049
|
-
|
|
2247
|
+
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
|
|
2248
|
+
*stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff,
|
|
2050
2249
|
src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
|
|
2051
2250
|
dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
|
|
2052
2251
|
}
|
|
@@ -2127,6 +2326,30 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
|
|
|
2127
2326
|
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
|
|
2128
2327
|
}
|
|
2129
2328
|
|
|
2329
|
+
inline void ggml_sycl_op_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
2330
|
+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
|
2331
|
+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
2332
|
+
|
|
2333
|
+
dpct::queue_ptr main_stream = ctx.stream();
|
|
2334
|
+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
|
2335
|
+
|
|
2336
|
+
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
|
2337
|
+
float * dst_dd = static_cast<float *>(dst->data);
|
|
2338
|
+
|
|
2339
|
+
const int64_t ncols = dst->src[0]->ne[0];
|
|
2340
|
+
const int64_t nrows = ggml_nrows(dst->src[0]);
|
|
2341
|
+
|
|
2342
|
+
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
|
|
2343
|
+
|
|
2344
|
+
main_stream->parallel_for(
|
|
2345
|
+
sycl::range<1>(nrows),
|
|
2346
|
+
[=](sycl::id<1> row) {
|
|
2347
|
+
dst_dd[row] /= ncols;
|
|
2348
|
+
}
|
|
2349
|
+
);
|
|
2350
|
+
}
|
|
2351
|
+
|
|
2352
|
+
|
|
2130
2353
|
inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
2131
2354
|
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
|
2132
2355
|
GGML_ASSERT(dst->type == GGML_TYPE_I32);
|
|
@@ -2141,7 +2364,32 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor *
|
|
|
2141
2364
|
|
|
2142
2365
|
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
|
2143
2366
|
|
|
2144
|
-
argsort_f32_i32_sycl(src0_dd, (int *)
|
|
2367
|
+
argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order,
|
|
2368
|
+
main_stream, ctx.device);
|
|
2369
|
+
}
|
|
2370
|
+
|
|
2371
|
+
static void ggml_sycl_op_top_k(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
2372
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
2373
|
+
|
|
2374
|
+
GGML_ASSERT(src0);
|
|
2375
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
2376
|
+
GGML_ASSERT(dst->type == GGML_TYPE_I32);
|
|
2377
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
2378
|
+
|
|
2379
|
+
dpct::queue_ptr main_stream = ctx.stream();
|
|
2380
|
+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
|
2381
|
+
|
|
2382
|
+
const float * src0_dd = static_cast<const float *>(src0->data);
|
|
2383
|
+
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
|
|
2384
|
+
|
|
2385
|
+
const int k = dst->ne[0];
|
|
2386
|
+
const int64_t ncols = src0->ne[0];
|
|
2387
|
+
const int64_t nrows = ggml_nrows(src0);
|
|
2388
|
+
|
|
2389
|
+
GGML_ASSERT(k > 0 && k <= 32);
|
|
2390
|
+
GGML_ASSERT(k <= ncols);
|
|
2391
|
+
|
|
2392
|
+
top_k_f32_sycl(src0_dd, dst_dd, ncols, nrows, k, main_stream);
|
|
2145
2393
|
}
|
|
2146
2394
|
|
|
2147
2395
|
inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
@@ -2176,6 +2424,65 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_ten
|
|
|
2176
2424
|
diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
|
|
2177
2425
|
}
|
|
2178
2426
|
|
|
2427
|
+
static void tri_f32_sycl(
|
|
2428
|
+
const float * src,
|
|
2429
|
+
float * dst,
|
|
2430
|
+
const int64_t ne0,
|
|
2431
|
+
const int64_t ne1,
|
|
2432
|
+
const int64_t ne2,
|
|
2433
|
+
const int64_t ne3,
|
|
2434
|
+
const ggml_tri_type ttype,
|
|
2435
|
+
dpct::queue_ptr main_stream
|
|
2436
|
+
) {
|
|
2437
|
+
const size_t total = (size_t) ne0 * (size_t) ne1 * (size_t) ne2 * (size_t) ne3;
|
|
2438
|
+
|
|
2439
|
+
main_stream->parallel_for(sycl::range<1>(total), [=](sycl::id<1> tid) {
|
|
2440
|
+
const int64_t idx = (int64_t) tid[0];
|
|
2441
|
+
|
|
2442
|
+
const int64_t i0 = idx % ne0;
|
|
2443
|
+
const int64_t t1 = idx / ne0;
|
|
2444
|
+
const int64_t i1 = t1 % ne1;
|
|
2445
|
+
|
|
2446
|
+
bool keep = false;
|
|
2447
|
+
switch (ttype) {
|
|
2448
|
+
case GGML_TRI_TYPE_LOWER: keep = (i0 < i1); break;
|
|
2449
|
+
case GGML_TRI_TYPE_LOWER_DIAG: keep = (i0 <= i1); break;
|
|
2450
|
+
case GGML_TRI_TYPE_UPPER: keep = (i0 > i1); break;
|
|
2451
|
+
case GGML_TRI_TYPE_UPPER_DIAG: keep = (i0 >= i1); break;
|
|
2452
|
+
default: keep = false; break;
|
|
2453
|
+
}
|
|
2454
|
+
|
|
2455
|
+
dst[idx] = keep ? src[idx] : 0.0f;
|
|
2456
|
+
});
|
|
2457
|
+
}
|
|
2458
|
+
|
|
2459
|
+
static void ggml_sycl_op_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
2460
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
2461
|
+
GGML_ASSERT(src0);
|
|
2462
|
+
|
|
2463
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
2464
|
+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
2465
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
2466
|
+
GGML_ASSERT(ggml_is_contiguous(dst));
|
|
2467
|
+
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
|
2468
|
+
|
|
2469
|
+
dpct::queue_ptr main_stream = ctx.stream();
|
|
2470
|
+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
|
2471
|
+
|
|
2472
|
+
const float * src0_dd = static_cast<const float *>(src0->data);
|
|
2473
|
+
float * dst_dd = static_cast<float *>(dst->data);
|
|
2474
|
+
|
|
2475
|
+
const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
|
|
2476
|
+
|
|
2477
|
+
const int64_t ne0 = src0->ne[0];
|
|
2478
|
+
const int64_t ne1 = src0->ne[1];
|
|
2479
|
+
const int64_t ne2 = src0->ne[2];
|
|
2480
|
+
const int64_t ne3 = src0->ne[3];
|
|
2481
|
+
|
|
2482
|
+
tri_f32_sycl(src0_dd, dst_dd, ne0, ne1, ne2, ne3, ttype, main_stream);
|
|
2483
|
+
}
|
|
2484
|
+
|
|
2485
|
+
|
|
2179
2486
|
inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
2180
2487
|
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
|
2181
2488
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
@@ -2548,6 +2855,10 @@ catch (sycl::exception const &exc) {
|
|
|
2548
2855
|
std::exit(1);
|
|
2549
2856
|
}
|
|
2550
2857
|
|
|
2858
|
+
static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
2859
|
+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
|
2860
|
+
ggml_sycl_op_repeat_back(ctx, dst);
|
|
2861
|
+
}
|
|
2551
2862
|
|
|
2552
2863
|
static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
2553
2864
|
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
|
@@ -2564,6 +2875,11 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
|
|
2564
2875
|
ggml_sycl_op_rms_norm(ctx, dst);
|
|
2565
2876
|
}
|
|
2566
2877
|
|
|
2878
|
+
static void ggml_sycl_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
2879
|
+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
|
2880
|
+
ggml_sycl_op_rms_norm_back(ctx, dst);
|
|
2881
|
+
}
|
|
2882
|
+
|
|
2567
2883
|
static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
2568
2884
|
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
|
2569
2885
|
ggml_sycl_op_l2_norm(ctx, dst);
|
|
@@ -2729,7 +3045,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
|
|
2729
3045
|
|
|
2730
3046
|
}
|
|
2731
3047
|
#if GGML_SYCL_DNNL
|
|
2732
|
-
// oneDNN handles strided data and does not need overhead of
|
|
3048
|
+
// oneDNN handles strided data and does not need overhead of ggml_get_to_fp16_nc_sycl
|
|
2733
3049
|
const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
|
|
2734
3050
|
src1_f16_alloc.alloc(ne_src1);
|
|
2735
3051
|
const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
|
|
@@ -2738,7 +3054,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
|
|
2738
3054
|
# else
|
|
2739
3055
|
const int64_t ne_src1 = ggml_nelements(src1);
|
|
2740
3056
|
src1_f16_alloc.alloc(ne_src1);
|
|
2741
|
-
const to_fp16_nc_sycl_t to_fp16_nc_sycl =
|
|
3057
|
+
const to_fp16_nc_sycl_t to_fp16_nc_sycl = ggml_get_to_fp16_nc_sycl(src1->type);
|
|
2742
3058
|
GGML_ASSERT(to_fp16_nc_sycl != nullptr);
|
|
2743
3059
|
to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
|
|
2744
3060
|
#endif
|
|
@@ -2882,8 +3198,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
|
|
2882
3198
|
const int64_t smb = ne12 == 1 ? s13 : s12;
|
|
2883
3199
|
|
|
2884
3200
|
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
|
|
2885
|
-
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::
|
|
2886
|
-
oneapi::
|
|
3201
|
+
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::mkl::transpose::trans,
|
|
3202
|
+
oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
|
|
2887
3203
|
src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma,
|
|
2888
3204
|
src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf,
|
|
2889
3205
|
mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
|
|
@@ -2907,7 +3223,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
|
|
2907
3223
|
});
|
|
2908
3224
|
|
|
2909
3225
|
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
|
2910
|
-
*queue, oneapi::
|
|
3226
|
+
*queue, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
|
|
2911
3227
|
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
|
|
2912
3228
|
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
|
|
2913
3229
|
(void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
|
|
@@ -2981,19 +3297,51 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
|
|
|
2981
3297
|
}
|
|
2982
3298
|
}
|
|
2983
3299
|
|
|
3300
|
+
// Helper functions to unify device memory allocation for both async and sync paths
|
|
3301
|
+
static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size) {
|
|
3302
|
+
bool use_async = g_ggml_sycl_use_async_mem_op;
|
|
3303
|
+
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
|
|
3304
|
+
if (use_async) {
|
|
3305
|
+
return syclex::async_malloc(*stream, sycl::usm::alloc::device, size);
|
|
3306
|
+
}
|
|
3307
|
+
#else
|
|
3308
|
+
// If async allocation extension is not available, use_async should always be false.
|
|
3309
|
+
GGML_ASSERT(!use_async);
|
|
3310
|
+
#endif
|
|
3311
|
+
return sycl::malloc(size, *stream, sycl::usm::alloc::device);
|
|
3312
|
+
}
|
|
3313
|
+
|
|
3314
|
+
static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) {
|
|
3315
|
+
bool use_async = g_ggml_sycl_use_async_mem_op;
|
|
3316
|
+
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
|
|
3317
|
+
if (use_async) {
|
|
3318
|
+
syclex::async_free(*stream, ptr);
|
|
3319
|
+
return;
|
|
3320
|
+
}
|
|
3321
|
+
#else
|
|
3322
|
+
// If async allocation extension is not available, use_async should always be false.
|
|
3323
|
+
GGML_ASSERT(!use_async);
|
|
3324
|
+
#endif
|
|
3325
|
+
sycl::free(ptr, *stream);
|
|
3326
|
+
}
|
|
3327
|
+
|
|
2984
3328
|
static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
|
|
2985
3329
|
dpct::queue_ptr stream) {
|
|
2986
|
-
|
|
2987
|
-
|
|
2988
|
-
|
|
2989
|
-
|
|
3330
|
+
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
|
|
3331
|
+
|
|
3332
|
+
sycl::event copy_event;
|
|
3333
|
+
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
|
|
3334
|
+
if (!g_ggml_sycl_use_async_mem_op) {
|
|
3335
|
+
copy_event.wait();
|
|
3336
|
+
}
|
|
3337
|
+
|
|
2990
3338
|
GGML_ASSERT((size % sizeof(block_q4_0) == 0));
|
|
2991
3339
|
GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
|
|
2992
3340
|
int offset_blks = offset / sizeof(block_q4_0);
|
|
2993
3341
|
auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
|
|
2994
3342
|
auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
|
|
2995
3343
|
|
|
2996
|
-
stream->parallel_for(
|
|
3344
|
+
auto reorder_event = stream->parallel_for(
|
|
2997
3345
|
size / sizeof(block_q4_0),
|
|
2998
3346
|
[=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
2999
3347
|
const block_q4_0* x = (const block_q4_0*)tmp_buf;
|
|
@@ -3004,9 +3352,11 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr
|
|
|
3004
3352
|
*(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
|
|
3005
3353
|
}
|
|
3006
3354
|
*(d_ptr + ib) = x[ib].d;
|
|
3007
|
-
})
|
|
3008
|
-
|
|
3009
|
-
|
|
3355
|
+
});
|
|
3356
|
+
if (!g_ggml_sycl_use_async_mem_op) {
|
|
3357
|
+
reorder_event.wait_and_throw();
|
|
3358
|
+
}
|
|
3359
|
+
sycl_ext_free(stream, tmp_buf);
|
|
3010
3360
|
}
|
|
3011
3361
|
|
|
3012
3362
|
static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
|
|
@@ -3015,14 +3365,19 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
|
|
|
3015
3365
|
|
|
3016
3366
|
const int nblocks = size / sizeof(block_q4_K);
|
|
3017
3367
|
|
|
3018
|
-
|
|
3019
|
-
|
|
3368
|
+
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
|
|
3369
|
+
|
|
3370
|
+
sycl::event copy_event;
|
|
3371
|
+
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
|
|
3372
|
+
if (!g_ggml_sycl_use_async_mem_op) {
|
|
3373
|
+
copy_event.wait();
|
|
3374
|
+
}
|
|
3020
3375
|
|
|
3021
3376
|
auto * qs_ptr = data_device;
|
|
3022
3377
|
auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
|
|
3023
3378
|
auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
|
|
3024
3379
|
|
|
3025
|
-
stream->parallel_for(nblocks, [=](auto i) {
|
|
3380
|
+
auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
|
|
3026
3381
|
const block_q4_K * x = (const block_q4_K *) tmp_buf;
|
|
3027
3382
|
const int ib = i;
|
|
3028
3383
|
|
|
@@ -3035,9 +3390,11 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
|
|
|
3035
3390
|
}
|
|
3036
3391
|
|
|
3037
3392
|
dm_ptr[ib] = x[ib].dm;
|
|
3038
|
-
})
|
|
3039
|
-
|
|
3040
|
-
|
|
3393
|
+
});
|
|
3394
|
+
if (!g_ggml_sycl_use_async_mem_op) {
|
|
3395
|
+
reorder_event.wait_and_throw();
|
|
3396
|
+
}
|
|
3397
|
+
sycl_ext_free(stream, tmp_buf);
|
|
3041
3398
|
}
|
|
3042
3399
|
|
|
3043
3400
|
static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
|
|
@@ -3046,42 +3403,46 @@ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d
|
|
|
3046
3403
|
|
|
3047
3404
|
const int nblocks = size / sizeof(block_q6_K);
|
|
3048
3405
|
|
|
3049
|
-
|
|
3050
|
-
|
|
3406
|
+
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
|
|
3407
|
+
|
|
3408
|
+
sycl::event copy_event;
|
|
3409
|
+
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
|
|
3410
|
+
if (!g_ggml_sycl_use_async_mem_op) {
|
|
3411
|
+
copy_event.wait();
|
|
3412
|
+
}
|
|
3051
3413
|
|
|
3052
3414
|
auto * ql_ptr = data_device;
|
|
3053
3415
|
auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks;
|
|
3054
3416
|
auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
|
|
3055
3417
|
sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
|
|
3056
3418
|
|
|
3057
|
-
stream
|
|
3058
|
-
|
|
3059
|
-
|
|
3060
|
-
const block_q6_K * x = (const block_q6_K *) tmp_buf;
|
|
3061
|
-
const int ib = i;
|
|
3062
|
-
|
|
3063
|
-
const uint8_t * ql = x[ib].ql;
|
|
3064
|
-
const uint8_t * qh = x[ib].qh;
|
|
3065
|
-
uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
|
|
3066
|
-
uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
|
|
3067
|
-
uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
|
|
3419
|
+
auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
|
|
3420
|
+
const block_q6_K * x = (const block_q6_K *) tmp_buf;
|
|
3421
|
+
const int ib = i;
|
|
3068
3422
|
|
|
3069
|
-
|
|
3070
|
-
|
|
3071
|
-
|
|
3072
|
-
|
|
3073
|
-
|
|
3074
|
-
}
|
|
3423
|
+
const uint8_t * ql = x[ib].ql;
|
|
3424
|
+
const uint8_t * qh = x[ib].qh;
|
|
3425
|
+
uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
|
|
3426
|
+
uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
|
|
3427
|
+
uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
|
|
3075
3428
|
|
|
3076
|
-
|
|
3077
|
-
|
|
3078
|
-
|
|
3429
|
+
for (int j = 0; j < QK_K / 2; ++j) {
|
|
3430
|
+
base_ql_ptr[j] = ql[j];
|
|
3431
|
+
}
|
|
3432
|
+
for (int j = 0; j < QK_K / 4; ++j) {
|
|
3433
|
+
base_qh_ptr[j] = qh[j];
|
|
3434
|
+
}
|
|
3079
3435
|
|
|
3080
|
-
|
|
3081
|
-
|
|
3082
|
-
|
|
3436
|
+
for (int j = 0; j < QK_K / 16; ++j) {
|
|
3437
|
+
base_scales_ptr[j] = x[ib].scales[j];
|
|
3438
|
+
}
|
|
3083
3439
|
|
|
3084
|
-
|
|
3440
|
+
dm_ptr[ib] = x[ib].d;
|
|
3441
|
+
});
|
|
3442
|
+
if (!g_ggml_sycl_use_async_mem_op) {
|
|
3443
|
+
reorder_event.wait_and_throw();
|
|
3444
|
+
}
|
|
3445
|
+
sycl_ext_free(stream, tmp_buf);
|
|
3085
3446
|
}
|
|
3086
3447
|
|
|
3087
3448
|
static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
|
|
@@ -3188,20 +3549,19 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
|
|
3188
3549
|
bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
|
|
3189
3550
|
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
|
|
3190
3551
|
|
|
3552
|
+
|
|
3191
3553
|
// mmvq and mmq need the __dp4a instruction which is available for gen12+
|
|
3192
|
-
// Workaround in https://github.com/
|
|
3554
|
+
// Workaround in https://github.com/ggml-org/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
|
|
3193
3555
|
use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
|
|
3194
3556
|
#ifdef SYCL_USE_XMX
|
|
3195
3557
|
use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
|
|
3196
3558
|
#endif // SYCL_USE_XMX
|
|
3197
3559
|
|
|
3198
|
-
|
|
3199
|
-
//
|
|
3200
|
-
if
|
|
3201
|
-
|
|
3202
|
-
|
|
3203
|
-
// requires disabling DMMV if both conditions are met
|
|
3204
|
-
|| (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) {
|
|
3560
|
+
// Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
|
|
3561
|
+
// is enabled takes precedence over DMMV, the current if-else implementation
|
|
3562
|
+
// requires disabling DMMV if both conditions are met
|
|
3563
|
+
if (!g_ggml_sycl_prioritize_dmmv && ((should_reorder_tensor(ctx, dst) &&
|
|
3564
|
+
ggml_sycl_supports_reorder_mmvq(src0->type)))) {
|
|
3205
3565
|
use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
|
|
3206
3566
|
}
|
|
3207
3567
|
|
|
@@ -3510,6 +3870,12 @@ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
|
|
3510
3870
|
ggml_sycl_op_sum_rows(ctx, dst);
|
|
3511
3871
|
}
|
|
3512
3872
|
|
|
3873
|
+
static void ggml_sycl_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
3874
|
+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
|
3875
|
+
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
|
|
3876
|
+
ggml_sycl_op_mean(ctx, dst);
|
|
3877
|
+
}
|
|
3878
|
+
|
|
3513
3879
|
static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
3514
3880
|
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
|
3515
3881
|
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
|
|
@@ -3561,9 +3927,15 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
3561
3927
|
case GGML_OP_REPEAT:
|
|
3562
3928
|
ggml_sycl_repeat(ctx, dst);
|
|
3563
3929
|
break;
|
|
3930
|
+
case GGML_OP_REPEAT_BACK:
|
|
3931
|
+
ggml_sycl_repeat_back(ctx, dst);
|
|
3932
|
+
break;
|
|
3564
3933
|
case GGML_OP_GET_ROWS:
|
|
3565
3934
|
ggml_sycl_get_rows(ctx, dst);
|
|
3566
3935
|
break;
|
|
3936
|
+
case GGML_OP_SET:
|
|
3937
|
+
ggml_sycl_op_set(ctx, dst);
|
|
3938
|
+
break;
|
|
3567
3939
|
case GGML_OP_SET_ROWS:
|
|
3568
3940
|
ggml_sycl_op_set_rows(ctx, dst);
|
|
3569
3941
|
break;
|
|
@@ -3574,6 +3946,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
3574
3946
|
case GGML_OP_ADD1: // TODO: more efficient implementation
|
|
3575
3947
|
ggml_sycl_add(ctx, dst);
|
|
3576
3948
|
break;
|
|
3949
|
+
case GGML_OP_ADD_ID:
|
|
3950
|
+
ggml_sycl_add_id(ctx, dst);
|
|
3951
|
+
break;
|
|
3577
3952
|
case GGML_OP_SUB:
|
|
3578
3953
|
ggml_sycl_sub(ctx, dst);
|
|
3579
3954
|
break;
|
|
@@ -3630,6 +4005,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
3630
4005
|
case GGML_UNARY_OP_EXP:
|
|
3631
4006
|
ggml_sycl_exp(ctx, dst);
|
|
3632
4007
|
break;
|
|
4008
|
+
case GGML_UNARY_OP_SOFTPLUS:
|
|
4009
|
+
ggml_sycl_softplus(ctx, dst);
|
|
4010
|
+
break;
|
|
3633
4011
|
case GGML_UNARY_OP_SGN:
|
|
3634
4012
|
ggml_sycl_sgn(ctx, dst);
|
|
3635
4013
|
break;
|
|
@@ -3639,6 +4017,18 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
3639
4017
|
case GGML_UNARY_OP_ELU:
|
|
3640
4018
|
ggml_sycl_elu(ctx, dst);
|
|
3641
4019
|
break;
|
|
4020
|
+
case GGML_UNARY_OP_FLOOR:
|
|
4021
|
+
ggml_sycl_floor(ctx, dst);
|
|
4022
|
+
break;
|
|
4023
|
+
case GGML_UNARY_OP_CEIL:
|
|
4024
|
+
ggml_sycl_ceil(ctx, dst);
|
|
4025
|
+
break;
|
|
4026
|
+
case GGML_UNARY_OP_ROUND:
|
|
4027
|
+
ggml_sycl_round(ctx, dst);
|
|
4028
|
+
break;
|
|
4029
|
+
case GGML_UNARY_OP_TRUNC:
|
|
4030
|
+
ggml_sycl_trunc(ctx, dst);
|
|
4031
|
+
break;
|
|
3642
4032
|
default:
|
|
3643
4033
|
return false;
|
|
3644
4034
|
}
|
|
@@ -3654,6 +4044,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
3654
4044
|
case GGML_GLU_OP_SWIGLU:
|
|
3655
4045
|
ggml_sycl_swiglu(ctx, dst);
|
|
3656
4046
|
break;
|
|
4047
|
+
case GGML_GLU_OP_SWIGLU_OAI:
|
|
4048
|
+
ggml_sycl_swiglu_oai(ctx, dst);
|
|
4049
|
+
break;
|
|
3657
4050
|
case GGML_GLU_OP_GEGLU_ERF:
|
|
3658
4051
|
ggml_sycl_geglu_erf(ctx, dst);
|
|
3659
4052
|
break;
|
|
@@ -3673,6 +4066,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
3673
4066
|
case GGML_OP_CONCAT:
|
|
3674
4067
|
ggml_sycl_op_concat(ctx, dst);
|
|
3675
4068
|
break;
|
|
4069
|
+
case GGML_OP_PAD_REFLECT_1D:
|
|
4070
|
+
ggml_sycl_op_pad_reflect_1d(ctx,dst);
|
|
4071
|
+
break;
|
|
3676
4072
|
case GGML_OP_UPSCALE:
|
|
3677
4073
|
ggml_sycl_upscale(ctx, dst);
|
|
3678
4074
|
break;
|
|
@@ -3682,6 +4078,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
3682
4078
|
case GGML_OP_LEAKY_RELU:
|
|
3683
4079
|
ggml_sycl_leaky_relu(ctx, dst);
|
|
3684
4080
|
break;
|
|
4081
|
+
case GGML_OP_RMS_NORM_BACK:
|
|
4082
|
+
ggml_sycl_rms_norm_back(ctx, dst);
|
|
4083
|
+
break;
|
|
3685
4084
|
case GGML_OP_RMS_NORM:
|
|
3686
4085
|
ggml_sycl_rms_norm(ctx, dst);
|
|
3687
4086
|
break;
|
|
@@ -3735,15 +4134,24 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
3735
4134
|
case GGML_OP_TRANSPOSE:
|
|
3736
4135
|
GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__);
|
|
3737
4136
|
break;
|
|
4137
|
+
case GGML_OP_TRI:
|
|
4138
|
+
ggml_sycl_op_tri(ctx, dst);
|
|
4139
|
+
break;
|
|
3738
4140
|
case GGML_OP_DIAG_MASK_INF:
|
|
3739
4141
|
ggml_sycl_diag_mask_inf(ctx, dst);
|
|
3740
4142
|
break;
|
|
3741
4143
|
case GGML_OP_SOFT_MAX:
|
|
3742
4144
|
ggml_sycl_op_soft_max(ctx, dst);
|
|
3743
4145
|
break;
|
|
4146
|
+
case GGML_OP_SOFT_MAX_BACK:
|
|
4147
|
+
ggml_sycl_op_soft_max_back(ctx, dst);
|
|
4148
|
+
break;
|
|
3744
4149
|
case GGML_OP_ROPE:
|
|
3745
4150
|
ggml_sycl_rope(ctx, dst);
|
|
3746
4151
|
break;
|
|
4152
|
+
case GGML_OP_ROPE_BACK:
|
|
4153
|
+
ggml_sycl_rope_back(ctx, dst);
|
|
4154
|
+
break;
|
|
3747
4155
|
case GGML_OP_IM2COL:
|
|
3748
4156
|
ggml_sycl_im2col(ctx, dst);
|
|
3749
4157
|
break;
|
|
@@ -3756,9 +4164,15 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
3756
4164
|
case GGML_OP_SUM_ROWS:
|
|
3757
4165
|
ggml_sycl_sum_rows(ctx, dst);
|
|
3758
4166
|
break;
|
|
4167
|
+
case GGML_OP_MEAN:
|
|
4168
|
+
ggml_sycl_mean(ctx, dst);
|
|
4169
|
+
break;
|
|
3759
4170
|
case GGML_OP_ARGSORT:
|
|
3760
4171
|
ggml_sycl_argsort(ctx, dst);
|
|
3761
4172
|
break;
|
|
4173
|
+
case GGML_OP_TOP_K:
|
|
4174
|
+
ggml_sycl_op_top_k(ctx, dst);
|
|
4175
|
+
break;
|
|
3762
4176
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
3763
4177
|
ggml_sycl_op_timestep_embedding(ctx, dst);
|
|
3764
4178
|
break;
|
|
@@ -3771,6 +4185,21 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
3771
4185
|
case GGML_OP_GATED_LINEAR_ATTN:
|
|
3772
4186
|
ggml_sycl_op_gated_linear_attn(ctx, dst);
|
|
3773
4187
|
break;
|
|
4188
|
+
case GGML_OP_GATED_DELTA_NET:
|
|
4189
|
+
ggml_sycl_gated_delta_net(ctx, dst);
|
|
4190
|
+
break;
|
|
4191
|
+
case GGML_OP_SSM_CONV:
|
|
4192
|
+
ggml_sycl_ssm_conv(ctx, dst);
|
|
4193
|
+
break;
|
|
4194
|
+
case GGML_OP_ROLL:
|
|
4195
|
+
ggml_sycl_roll(ctx, dst);
|
|
4196
|
+
break;
|
|
4197
|
+
case GGML_OP_ARANGE:
|
|
4198
|
+
ggml_sycl_arange(ctx, dst);
|
|
4199
|
+
break;
|
|
4200
|
+
case GGML_OP_FLASH_ATTN_EXT:
|
|
4201
|
+
ggml_sycl_flash_attn_ext(ctx, dst);
|
|
4202
|
+
break;
|
|
3774
4203
|
default:
|
|
3775
4204
|
return false;
|
|
3776
4205
|
}
|
|
@@ -3778,6 +4207,7 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|
|
3778
4207
|
return true;
|
|
3779
4208
|
} catch (sycl::exception & e) {
|
|
3780
4209
|
std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
|
|
4210
|
+
std::cerr << "Error OP "<<ggml_op_name(dst->op)<< std::endl;
|
|
3781
4211
|
std::exit(1);
|
|
3782
4212
|
}
|
|
3783
4213
|
|
|
@@ -3800,16 +4230,6 @@ void ggml_backend_sycl_get_device_memory(int device, size_t *free,
|
|
|
3800
4230
|
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_memory\n");
|
|
3801
4231
|
ggml_sycl_set_device(device);
|
|
3802
4232
|
|
|
3803
|
-
/*
|
|
3804
|
-
DPCT1009:218: SYCL uses exceptions to report errors and does not use the
|
|
3805
|
-
error codes. The original code was commented out and a warning string was
|
|
3806
|
-
inserted. You need to rewrite this code.
|
|
3807
|
-
*/
|
|
3808
|
-
/*
|
|
3809
|
-
DPCT1106:217: 'cudaMemGetInfo' was migrated with the Intel extensions for
|
|
3810
|
-
device information which may not be supported by all compilers or runtimes.
|
|
3811
|
-
You may need to adjust the code.
|
|
3812
|
-
*/
|
|
3813
4233
|
SYCL_CHECK(CHECK_TRY_ERROR(
|
|
3814
4234
|
dpct::dev_mgr::instance().get_device(device).get_memory_info(*free, *total)));
|
|
3815
4235
|
}
|
|
@@ -3931,6 +4351,9 @@ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * syc
|
|
|
3931
4351
|
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
|
3932
4352
|
continue;
|
|
3933
4353
|
}
|
|
4354
|
+
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
|
4355
|
+
continue;
|
|
4356
|
+
}
|
|
3934
4357
|
#ifndef NDEBUG
|
|
3935
4358
|
assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
|
|
3936
4359
|
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
|
@@ -3972,6 +4395,18 @@ static bool check_graph_compatibility(ggml_cgraph * cgraph) {
|
|
|
3972
4395
|
GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__,
|
|
3973
4396
|
ggml_op_name(node_op));
|
|
3974
4397
|
return false;
|
|
4398
|
+
case GGML_OP_MUL_MAT:
|
|
4399
|
+
// We cannot use graphs with ggml_sycl_mul_mat() when SYCL async memory allocation extensions are not available,
|
|
4400
|
+
// as SYCL malloc / free and host wait calls are not supported when recording to a graph which are all present
|
|
4401
|
+
// in reordering.
|
|
4402
|
+
if (!g_ggml_sycl_use_async_mem_op) {
|
|
4403
|
+
GGML_LOG_INFO(
|
|
4404
|
+
"%s: disabling SYCL graphs due to unsupported node type when using a compiler without the "
|
|
4405
|
+
"oneAPI async memory allocation extension "
|
|
4406
|
+
"%s\n",
|
|
4407
|
+
__func__, ggml_op_name(node_op));
|
|
4408
|
+
return false;
|
|
4409
|
+
}
|
|
3975
4410
|
}
|
|
3976
4411
|
}
|
|
3977
4412
|
return true;
|
|
@@ -4096,6 +4531,7 @@ struct ggml_backend_sycl_device_context {
|
|
|
4096
4531
|
int device;
|
|
4097
4532
|
std::string name;
|
|
4098
4533
|
std::string description;
|
|
4534
|
+
int op_offload_min_batch_size;
|
|
4099
4535
|
};
|
|
4100
4536
|
|
|
4101
4537
|
static const char * ggml_backend_sycl_device_get_name(ggml_backend_dev_t dev) {
|
|
@@ -4166,6 +4602,9 @@ static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_
|
|
|
4166
4602
|
}
|
|
4167
4603
|
|
|
4168
4604
|
static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
|
4605
|
+
ggml_backend_sycl_device_context *sycl_ctx =
|
|
4606
|
+
(ggml_backend_sycl_device_context *)dev->context;
|
|
4607
|
+
int device = sycl_ctx->device;
|
|
4169
4608
|
switch (op->op) {
|
|
4170
4609
|
case GGML_OP_CONV_TRANSPOSE_1D:
|
|
4171
4610
|
{
|
|
@@ -4178,21 +4617,27 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
4178
4617
|
}
|
|
4179
4618
|
case GGML_OP_UNARY:
|
|
4180
4619
|
switch (ggml_get_unary_op(op)) {
|
|
4620
|
+
case GGML_UNARY_OP_SGN:
|
|
4621
|
+
case GGML_UNARY_OP_ABS:
|
|
4181
4622
|
case GGML_UNARY_OP_NEG:
|
|
4182
4623
|
case GGML_UNARY_OP_STEP:
|
|
4624
|
+
case GGML_UNARY_OP_RELU:
|
|
4625
|
+
case GGML_UNARY_OP_HARDSIGMOID:
|
|
4626
|
+
case GGML_UNARY_OP_TANH:
|
|
4183
4627
|
case GGML_UNARY_OP_GELU:
|
|
4184
4628
|
case GGML_UNARY_OP_SILU:
|
|
4185
|
-
case GGML_UNARY_OP_RELU:
|
|
4186
4629
|
case GGML_UNARY_OP_SIGMOID:
|
|
4187
|
-
case GGML_UNARY_OP_HARDSIGMOID:
|
|
4188
4630
|
case GGML_UNARY_OP_HARDSWISH:
|
|
4189
4631
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
4190
4632
|
case GGML_UNARY_OP_GELU_ERF:
|
|
4191
|
-
case GGML_UNARY_OP_TANH:
|
|
4192
4633
|
case GGML_UNARY_OP_EXP:
|
|
4193
|
-
case
|
|
4194
|
-
case GGML_UNARY_OP_ABS:
|
|
4634
|
+
case GGML_UNARY_OP_SOFTPLUS:
|
|
4195
4635
|
case GGML_UNARY_OP_ELU:
|
|
4636
|
+
case GGML_UNARY_OP_CEIL:
|
|
4637
|
+
return true;
|
|
4638
|
+
case GGML_UNARY_OP_FLOOR:
|
|
4639
|
+
case GGML_UNARY_OP_ROUND:
|
|
4640
|
+
case GGML_UNARY_OP_TRUNC:
|
|
4196
4641
|
#if defined (GGML_SYCL_F16)
|
|
4197
4642
|
return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
|
|
4198
4643
|
#else
|
|
@@ -4206,6 +4651,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
4206
4651
|
case GGML_GLU_OP_REGLU:
|
|
4207
4652
|
case GGML_GLU_OP_GEGLU:
|
|
4208
4653
|
case GGML_GLU_OP_SWIGLU:
|
|
4654
|
+
case GGML_GLU_OP_SWIGLU_OAI:
|
|
4209
4655
|
case GGML_GLU_OP_GEGLU_ERF:
|
|
4210
4656
|
case GGML_GLU_OP_GEGLU_QUICK:
|
|
4211
4657
|
return ggml_is_contiguous_1(op->src[0]);
|
|
@@ -4233,15 +4679,18 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
4233
4679
|
}
|
|
4234
4680
|
}
|
|
4235
4681
|
ggml_type src0_type = op->src[0]->type;
|
|
4236
|
-
if (src0_type == GGML_TYPE_BF16
|
|
4237
|
-
// TODO: support
|
|
4682
|
+
if (src0_type == GGML_TYPE_BF16 ) {
|
|
4683
|
+
// TODO: support GGML_TYPE_BF16
|
|
4238
4684
|
// FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
|
|
4239
4685
|
return false;
|
|
4240
4686
|
}
|
|
4687
|
+
|
|
4241
4688
|
// TODO: The configuration below needs more work to be supported with oneDNN
|
|
4242
|
-
if (ggml_is_permuted(a) && !ggml_is_contiguous(a) &&
|
|
4243
|
-
|
|
4689
|
+
if (ggml_is_permuted(a) && !ggml_is_contiguous(a) &&
|
|
4690
|
+
a->ne[2] > 1 && a->ne[3] > 1 && src0_type == GGML_TYPE_F16) {
|
|
4691
|
+
return false;
|
|
4244
4692
|
}
|
|
4693
|
+
|
|
4245
4694
|
// TODO: This specific configuration can fail with oneDNN and needs more debugging
|
|
4246
4695
|
if (!ggml_is_permuted(a) && ggml_is_permuted(b) && b->ne[2] > 1 && b->ne[3] > 1 &&
|
|
4247
4696
|
a->ne[0] > 128 && a->ne[2] == 1 && src0_type == GGML_TYPE_F16) {
|
|
@@ -4266,6 +4715,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
4266
4715
|
return false;
|
|
4267
4716
|
}
|
|
4268
4717
|
}
|
|
4718
|
+
case GGML_OP_SET:
|
|
4719
|
+
return (op->type == GGML_TYPE_F32) &&
|
|
4720
|
+
(op->src[0] && op->src[1]) &&
|
|
4721
|
+
(op->src[0]->type == GGML_TYPE_F32) &&
|
|
4722
|
+
(op->src[1]->type == GGML_TYPE_F32);
|
|
4723
|
+
|
|
4269
4724
|
case GGML_OP_SET_ROWS:
|
|
4270
4725
|
{
|
|
4271
4726
|
return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
|
|
@@ -4343,11 +4798,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
4343
4798
|
}
|
|
4344
4799
|
return false;
|
|
4345
4800
|
}
|
|
4346
|
-
case
|
|
4801
|
+
case GGML_OP_REPEAT_BACK:
|
|
4347
4802
|
{
|
|
4348
4803
|
ggml_type src0_type = op->src[0]->type;
|
|
4349
|
-
return src0_type
|
|
4804
|
+
return src0_type == GGML_TYPE_F32;
|
|
4350
4805
|
}
|
|
4806
|
+
case GGML_OP_CONCAT:
|
|
4351
4807
|
case GGML_OP_DUP:
|
|
4352
4808
|
case GGML_OP_ARGMAX:
|
|
4353
4809
|
case GGML_OP_NONE:
|
|
@@ -4355,15 +4811,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
4355
4811
|
case GGML_OP_VIEW:
|
|
4356
4812
|
case GGML_OP_PERMUTE:
|
|
4357
4813
|
case GGML_OP_TRANSPOSE:
|
|
4358
|
-
return true;
|
|
4359
4814
|
case GGML_OP_ADD:
|
|
4360
4815
|
case GGML_OP_ADD1:
|
|
4816
|
+
case GGML_OP_ADD_ID:
|
|
4361
4817
|
case GGML_OP_SUB:
|
|
4362
4818
|
case GGML_OP_COUNT_EQUAL:
|
|
4363
4819
|
case GGML_OP_MUL:
|
|
4364
4820
|
case GGML_OP_DIV:
|
|
4365
4821
|
case GGML_OP_REPEAT:
|
|
4366
4822
|
return true;
|
|
4823
|
+
case GGML_OP_PAD_REFLECT_1D:
|
|
4824
|
+
return ggml_is_contiguous(op->src[0]) && op-> type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
|
|
4367
4825
|
case GGML_OP_SQR:
|
|
4368
4826
|
case GGML_OP_SQRT:
|
|
4369
4827
|
case GGML_OP_SIN:
|
|
@@ -4376,50 +4834,81 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
4376
4834
|
return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
|
|
4377
4835
|
#endif
|
|
4378
4836
|
case GGML_OP_NORM:
|
|
4379
|
-
return true;
|
|
4380
4837
|
case GGML_OP_L2_NORM:
|
|
4381
4838
|
case GGML_OP_GROUP_NORM:
|
|
4382
|
-
return ggml_is_contiguous(op->src[0]);
|
|
4383
4839
|
case GGML_OP_RMS_NORM:
|
|
4384
|
-
return
|
|
4840
|
+
return true;
|
|
4841
|
+
case GGML_OP_RMS_NORM_BACK:
|
|
4842
|
+
return ggml_is_contiguous(op->src[0]);
|
|
4385
4843
|
case GGML_OP_SCALE:
|
|
4386
4844
|
return true;
|
|
4387
4845
|
case GGML_OP_CONT:
|
|
4388
4846
|
return op->src[0]->type != GGML_TYPE_BF16;
|
|
4389
|
-
case
|
|
4390
|
-
|
|
4391
|
-
|
|
4392
|
-
return
|
|
4393
|
-
|
|
4394
|
-
|
|
4395
|
-
if (op->src[2]) {
|
|
4396
|
-
return false;
|
|
4847
|
+
case GGML_OP_TRI:
|
|
4848
|
+
{
|
|
4849
|
+
const ggml_tensor * src0 = op->src[0];
|
|
4850
|
+
return src0 &&
|
|
4851
|
+
op->type == GGML_TYPE_F32 &&
|
|
4852
|
+
ggml_is_contiguous(src0);
|
|
4397
4853
|
}
|
|
4398
|
-
// TODO: support broadcast
|
|
4399
|
-
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
|
|
4400
|
-
return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
|
|
4401
4854
|
case GGML_OP_DIAG_MASK_INF:
|
|
4855
|
+
return true;
|
|
4856
|
+
case GGML_OP_SOFT_MAX:
|
|
4857
|
+
return true;
|
|
4858
|
+
case GGML_OP_SOFT_MAX_BACK: {
|
|
4859
|
+
float max_bias = 0.0f;
|
|
4860
|
+
memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
|
|
4861
|
+
return max_bias == 0.0f;
|
|
4862
|
+
}
|
|
4402
4863
|
case GGML_OP_ROPE:
|
|
4864
|
+
case GGML_OP_ROPE_BACK:
|
|
4403
4865
|
case GGML_OP_IM2COL:
|
|
4404
4866
|
return true;
|
|
4405
4867
|
case GGML_OP_UPSCALE:
|
|
4406
|
-
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
|
4868
|
+
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
|
|
4407
4869
|
case GGML_OP_SUM:
|
|
4408
4870
|
case GGML_OP_SUM_ROWS:
|
|
4409
|
-
case
|
|
4871
|
+
case GGML_OP_MEAN:
|
|
4410
4872
|
return ggml_is_contiguous(op->src[0]);
|
|
4873
|
+
case GGML_OP_ARGSORT:
|
|
4874
|
+
return op->src[0]->ne[0] * sizeof(int) <=
|
|
4875
|
+
ggml_sycl_info().devices[device].smpbo;
|
|
4876
|
+
case GGML_OP_TOP_K: {
|
|
4877
|
+
const ggml_tensor * src0 = op->src[0];
|
|
4878
|
+
const int k = op->ne[0];
|
|
4879
|
+
return src0 &&
|
|
4880
|
+
op->type == GGML_TYPE_I32 &&
|
|
4881
|
+
src0->type == GGML_TYPE_F32 &&
|
|
4882
|
+
ggml_is_contiguous(src0) &&
|
|
4883
|
+
k > 0 && k <= 32;
|
|
4884
|
+
}
|
|
4411
4885
|
case GGML_OP_POOL_2D:
|
|
4412
|
-
case GGML_OP_ACC:
|
|
4413
4886
|
return true;
|
|
4887
|
+
case GGML_OP_ACC:
|
|
4888
|
+
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
|
|
4414
4889
|
case GGML_OP_PAD:
|
|
4415
|
-
|
|
4416
|
-
|
|
4890
|
+
// TODO: add circular padding support for syscl, see https://github.com/ggml-org/llama.cpp/pull/16985
|
|
4891
|
+
if (ggml_get_op_params_i32(op, 8) != 0) {
|
|
4892
|
+
return false;
|
|
4893
|
+
}
|
|
4894
|
+
return ggml_is_contiguous(op->src[0]);
|
|
4417
4895
|
case GGML_OP_LEAKY_RELU:
|
|
4418
4896
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
4419
4897
|
case GGML_OP_RWKV_WKV6:
|
|
4420
4898
|
case GGML_OP_RWKV_WKV7:
|
|
4421
4899
|
case GGML_OP_GATED_LINEAR_ATTN:
|
|
4900
|
+
case GGML_OP_GATED_DELTA_NET:
|
|
4422
4901
|
return true;
|
|
4902
|
+
case GGML_OP_SSM_CONV:
|
|
4903
|
+
return op->type == GGML_TYPE_F32 &&
|
|
4904
|
+
op->src[0]->type == GGML_TYPE_F32 &&
|
|
4905
|
+
op->src[1]->type == GGML_TYPE_F32;
|
|
4906
|
+
case GGML_OP_ROLL:
|
|
4907
|
+
return op->type == GGML_TYPE_F32;
|
|
4908
|
+
case GGML_OP_ARANGE:
|
|
4909
|
+
return op->type == GGML_TYPE_F32;
|
|
4910
|
+
case GGML_OP_FLASH_ATTN_EXT:
|
|
4911
|
+
return ggml_sycl_flash_attn_ext_supported(device, op);
|
|
4423
4912
|
default:
|
|
4424
4913
|
return false;
|
|
4425
4914
|
}
|
|
@@ -4451,9 +4940,8 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
|
|
|
4451
4940
|
}
|
|
4452
4941
|
|
|
4453
4942
|
static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
|
4454
|
-
|
|
4455
|
-
return get_op_batch_size(op) >=
|
|
4456
|
-
GGML_UNUSED(dev);
|
|
4943
|
+
ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context;
|
|
4944
|
+
return get_op_batch_size(op) >= sycl_ctx->op_offload_min_batch_size;
|
|
4457
4945
|
}
|
|
4458
4946
|
|
|
4459
4947
|
static ggml_backend_event_t
|
|
@@ -4576,6 +5064,7 @@ ggml_backend_reg_t ggml_backend_sycl_reg() {
|
|
|
4576
5064
|
std::lock_guard<std::mutex> lock(mutex);
|
|
4577
5065
|
if (!initialized) {
|
|
4578
5066
|
ggml_backend_sycl_reg_context * ctx = new ggml_backend_sycl_reg_context;
|
|
5067
|
+
const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
|
|
4579
5068
|
|
|
4580
5069
|
for (int i = 0; i < ggml_sycl_info().device_count; i++) {
|
|
4581
5070
|
ggml_backend_sycl_device_context * dev_ctx = new ggml_backend_sycl_device_context;
|
|
@@ -4589,6 +5078,7 @@ ggml_backend_reg_t ggml_backend_sycl_reg() {
|
|
|
4589
5078
|
prop, dpct::dev_mgr::instance().get_device(i))));
|
|
4590
5079
|
|
|
4591
5080
|
dev_ctx->description = prop.get_name();
|
|
5081
|
+
dev_ctx->op_offload_min_batch_size = min_batch_size;
|
|
4592
5082
|
|
|
4593
5083
|
ggml_backend_dev_t dev = new ggml_backend_device {
|
|
4594
5084
|
/* .iface = */ ggml_backend_sycl_device_interface,
|