whispercpp 1.3.5 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +99 -2
- data/ext/extconf.rb +1 -0
- data/ext/ruby_whisper.c +20 -4
- data/ext/ruby_whisper.h +30 -2
- data/ext/ruby_whisper_context.c +216 -124
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_params.c +0 -1
- data/ext/ruby_whisper_segment.c +0 -1
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +4 -1
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +8 -0
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/server/server.cpp +18 -4
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
- data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
- data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
- data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
- data/ext/sources/examples/talk-llama/llama-context.h +27 -28
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
- data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
- data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
- data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
- data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
- data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
- data/ext/sources/examples/talk-llama/llama-model.h +72 -19
- data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
- data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
- data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -22
- data/ext/sources/examples/talk-llama/llama.h +63 -30
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
- data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
- data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
- data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
- data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
- data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/models.h +181 -46
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
- data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
- data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
- data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
- data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
- data/ext/sources/ggml/CMakeLists.txt +9 -3
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +6 -1
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +56 -9
- data/ext/sources/ggml/src/CMakeLists.txt +3 -0
- data/ext/sources/ggml/src/ggml-alloc.c +4 -9
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
- data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
- data/ext/sources/ggml/src/ggml-impl.h +62 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
- data/ext/sources/ggml/src/ggml.c +167 -33
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/src/whisper.cpp +6 -28
- data/sig/whisper.rbs +43 -2
- data/test/test_context_params.rb +82 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_whisper.rb +20 -0
- data/whispercpp.gemspec +1 -1
- metadata +240 -28
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
#include "rope.hpp"
|
|
2
|
+
#include "convert.hpp"
|
|
2
3
|
#include "ggml-sycl/common.hpp"
|
|
3
4
|
#include "ggml.h"
|
|
4
5
|
|
|
@@ -15,367 +16,489 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
|
|
15
16
|
return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
|
|
16
17
|
}
|
|
17
18
|
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
// Get n-d rotational scaling corrected for extrapolation
|
|
19
|
+
template <bool forward>
|
|
20
|
+
static void rope_yarn(const float theta_extrap, const float freq_scale,
|
|
21
|
+
const rope_corr_dims corr_dims, const int64_t i0,
|
|
22
|
+
const float ext_factor, float mscale, float &cos_theta,
|
|
23
|
+
float &sin_theta) {
|
|
24
24
|
float theta_interp = freq_scale * theta_extrap;
|
|
25
25
|
float theta = theta_interp;
|
|
26
26
|
if (ext_factor != 0.0f) {
|
|
27
|
-
float ramp_mix =
|
|
27
|
+
float ramp_mix =
|
|
28
|
+
rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
|
|
28
29
|
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
|
29
30
|
|
|
30
|
-
// Get n-d magnitude scaling corrected for interpolation
|
|
31
31
|
mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
|
|
32
32
|
}
|
|
33
|
-
|
|
34
|
-
|
|
33
|
+
cos_theta = sycl::cos(theta) * mscale;
|
|
34
|
+
sin_theta = sycl::sin(theta) * mscale;
|
|
35
|
+
if (!forward) {
|
|
36
|
+
sin_theta *= -1.0f;
|
|
37
|
+
}
|
|
35
38
|
}
|
|
36
39
|
|
|
37
|
-
template <typename T,
|
|
38
|
-
static void rope_norm(const T *
|
|
39
|
-
const
|
|
40
|
-
const
|
|
41
|
-
const
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
40
|
+
template <bool forward, bool has_ff, typename T, typename D>
|
|
41
|
+
static void rope_norm(const T *x, D *dst, const int ne00, const int ne01,
|
|
42
|
+
const int ne02, const int s01, const int s02,
|
|
43
|
+
const int s03, const int s1, const int s2, const int s3,
|
|
44
|
+
const int n_dims, const int32_t *pos,
|
|
45
|
+
const float freq_scale, const float ext_factor,
|
|
46
|
+
const float attn_factor, const rope_corr_dims corr_dims,
|
|
47
|
+
const float theta_scale, const float *freq_factors,
|
|
48
|
+
const int64_t *row_indices, const int set_rows_stride) {
|
|
49
|
+
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
|
50
|
+
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
|
51
|
+
item_ct1.get_local_id(1));
|
|
52
|
+
|
|
53
|
+
if (i0 >= ne00) {
|
|
45
54
|
return;
|
|
46
55
|
}
|
|
47
56
|
|
|
48
|
-
const int
|
|
57
|
+
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
58
|
+
item_ct1.get_local_id(2);
|
|
49
59
|
|
|
50
|
-
const
|
|
51
|
-
const
|
|
60
|
+
const uint32_t i3 = row_dst / (ne01 * ne02);
|
|
61
|
+
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
|
|
62
|
+
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
|
|
52
63
|
|
|
53
|
-
|
|
54
|
-
const int
|
|
64
|
+
int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3;
|
|
65
|
+
const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03;
|
|
66
|
+
|
|
67
|
+
if (set_rows_stride != 0) {
|
|
68
|
+
idst = i1 * s1 + i0;
|
|
69
|
+
idst += row_indices[i2] * set_rows_stride;
|
|
70
|
+
}
|
|
55
71
|
|
|
72
|
+
const auto &store_coaelsced = [&](float x0, float x1) {
|
|
73
|
+
if constexpr (std::is_same_v<float, D>) {
|
|
74
|
+
sycl::float2 v = sycl::float2(x0, x1);
|
|
75
|
+
ggml_sycl_memcpy_1<8>(dst + idst, &v);
|
|
76
|
+
} else if constexpr (std::is_same_v<sycl::half, D>) {
|
|
77
|
+
sycl::half2 v = sycl::half2(x0, x1);
|
|
78
|
+
ggml_sycl_memcpy_1<4>(dst + idst, &v);
|
|
79
|
+
}
|
|
80
|
+
};
|
|
56
81
|
if (i0 >= n_dims) {
|
|
57
|
-
|
|
82
|
+
store_coaelsced(x[ix + 0], x[ix + 1]);
|
|
58
83
|
return;
|
|
59
84
|
}
|
|
60
85
|
|
|
61
|
-
const float theta_base = pos[
|
|
86
|
+
const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
|
|
62
87
|
|
|
63
88
|
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
|
64
89
|
|
|
65
90
|
float cos_theta;
|
|
66
91
|
float sin_theta;
|
|
67
92
|
|
|
68
|
-
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0,
|
|
93
|
+
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
|
|
94
|
+
ext_factor, attn_factor, cos_theta, sin_theta);
|
|
69
95
|
|
|
70
|
-
const float x0 = x[
|
|
71
|
-
const float x1 = x[
|
|
96
|
+
const float x0 = x[ix + 0];
|
|
97
|
+
const float x1 = x[ix + 1];
|
|
72
98
|
|
|
73
|
-
|
|
74
|
-
|
|
99
|
+
store_coaelsced(x0 * cos_theta - x1 * sin_theta,
|
|
100
|
+
x0 * sin_theta + x1 * cos_theta);
|
|
75
101
|
}
|
|
76
102
|
|
|
77
|
-
template <typename T,
|
|
78
|
-
static void rope_neox(const T *
|
|
79
|
-
const
|
|
80
|
-
const
|
|
81
|
-
const
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
103
|
+
template <bool forward, bool has_ff, typename T, typename D>
|
|
104
|
+
static void rope_neox(const T *x, D *dst, const int ne00, const int ne01,
|
|
105
|
+
const int ne02, const int s01, const int s02,
|
|
106
|
+
const int s03, const int s1, const int s2, const int s3,
|
|
107
|
+
const int n_dims, const int32_t *pos,
|
|
108
|
+
const float freq_scale, const float ext_factor,
|
|
109
|
+
const float attn_factor, const rope_corr_dims corr_dims,
|
|
110
|
+
const float theta_scale, const float *freq_factors,
|
|
111
|
+
const int64_t *row_indices, const int set_rows_stride) {
|
|
112
|
+
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
|
113
|
+
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
|
114
|
+
item_ct1.get_local_id(1));
|
|
115
|
+
|
|
116
|
+
if (i0 >= ne00) {
|
|
85
117
|
return;
|
|
86
118
|
}
|
|
87
119
|
|
|
88
|
-
const int
|
|
120
|
+
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
121
|
+
item_ct1.get_local_id(2);
|
|
89
122
|
|
|
90
|
-
const
|
|
91
|
-
const
|
|
123
|
+
const uint32_t i3 = row_dst / (ne01 * ne02);
|
|
124
|
+
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
|
|
125
|
+
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
|
|
92
126
|
|
|
93
|
-
|
|
94
|
-
const int
|
|
127
|
+
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
|
|
128
|
+
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
|
|
129
|
+
|
|
130
|
+
if (set_rows_stride != 0) {
|
|
131
|
+
idst = i1 * s1 + i0 / 2;
|
|
132
|
+
idst += row_indices[i2] * set_rows_stride;
|
|
133
|
+
}
|
|
95
134
|
|
|
96
135
|
if (i0 >= n_dims) {
|
|
97
|
-
|
|
136
|
+
dst[idst + i0 / 2 + 0] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 0]);
|
|
137
|
+
dst[idst + i0 / 2 + 1] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 1]);
|
|
138
|
+
|
|
98
139
|
return;
|
|
99
140
|
}
|
|
100
141
|
|
|
101
|
-
const float theta_base = pos[
|
|
142
|
+
const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
|
|
102
143
|
|
|
103
144
|
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
|
104
145
|
|
|
105
146
|
float cos_theta;
|
|
106
147
|
float sin_theta;
|
|
107
148
|
|
|
108
|
-
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0,
|
|
149
|
+
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
|
|
150
|
+
ext_factor, attn_factor, cos_theta, sin_theta);
|
|
109
151
|
|
|
110
|
-
const float x0 = x[
|
|
111
|
-
const float x1 = x[
|
|
152
|
+
const float x0 = x[ix + 0];
|
|
153
|
+
const float x1 = x[ix + n_dims / 2];
|
|
112
154
|
|
|
113
|
-
dst[
|
|
114
|
-
dst[
|
|
155
|
+
dst[idst + 0] = ggml_sycl_cast<D>(x0 * cos_theta - x1 * sin_theta);
|
|
156
|
+
dst[idst + n_dims / 2] = ggml_sycl_cast<D>(x0 * sin_theta + x1 * cos_theta);
|
|
115
157
|
}
|
|
116
158
|
|
|
117
|
-
template <
|
|
118
|
-
static void rope_multi(const T *
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
159
|
+
template <bool forward, bool has_ff, typename T>
|
|
160
|
+
static void rope_multi(const T *x, T *dst, const int ne00, const int ne01,
|
|
161
|
+
const int ne02, const int s01, const int s02,
|
|
162
|
+
const int s03, const int s1, const int s2, const int s3,
|
|
163
|
+
const int n_dims, const int32_t *pos,
|
|
164
|
+
const float freq_scale, const float ext_factor,
|
|
165
|
+
const float attn_factor, const rope_corr_dims corr_dims,
|
|
166
|
+
const float theta_scale, const float *freq_factors,
|
|
167
|
+
const mrope_sections sections, const bool is_imrope) {
|
|
168
|
+
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
|
169
|
+
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
|
170
|
+
item_ct1.get_local_id(1));
|
|
171
|
+
|
|
172
|
+
if (i0 >= ne00) {
|
|
126
173
|
return;
|
|
127
174
|
}
|
|
128
|
-
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
|
|
129
175
|
|
|
130
|
-
const int
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
const
|
|
176
|
+
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
177
|
+
item_ct1.get_local_id(2);
|
|
178
|
+
|
|
179
|
+
const uint32_t i3 = row_dst / (ne01 * ne02);
|
|
180
|
+
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
|
|
181
|
+
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
|
|
182
|
+
|
|
183
|
+
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
|
|
184
|
+
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
|
|
134
185
|
|
|
135
186
|
if (i0 >= n_dims) {
|
|
136
|
-
|
|
187
|
+
dst[idst + i0 / 2 + 0] = x[ix + i0 / 2 + 0];
|
|
188
|
+
dst[idst + i0 / 2 + 1] = x[ix + i0 / 2 + 1];
|
|
189
|
+
|
|
137
190
|
return;
|
|
138
191
|
}
|
|
139
192
|
|
|
140
|
-
const int sect_dims =
|
|
193
|
+
const int sect_dims =
|
|
194
|
+
sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
|
|
141
195
|
const int sec_w = sections.v[1] + sections.v[0];
|
|
142
196
|
const int sector = (i0 / 2) % sect_dims;
|
|
143
197
|
|
|
144
|
-
|
|
145
198
|
float theta_base = 0.0;
|
|
146
199
|
if (is_imrope) {
|
|
147
|
-
if (sector % 3 == 1 && sector < 3 * sections.v[1]) {
|
|
148
|
-
theta_base = pos[
|
|
149
|
-
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) {
|
|
150
|
-
theta_base = pos[
|
|
151
|
-
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {
|
|
152
|
-
theta_base = pos[
|
|
200
|
+
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
|
|
201
|
+
theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
|
|
202
|
+
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
|
|
203
|
+
theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
|
|
204
|
+
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
|
|
205
|
+
theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
|
|
153
206
|
} else {
|
|
154
|
-
theta_base = pos[
|
|
207
|
+
theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
|
|
155
208
|
}
|
|
156
209
|
} else {
|
|
157
210
|
if (sector < sections.v[0]) {
|
|
158
|
-
theta_base = pos[
|
|
159
|
-
}
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
else if (sector >= sec_w
|
|
164
|
-
theta_base = pos[
|
|
165
|
-
}
|
|
166
|
-
else if (sector >= sec_w + sections.v[2]) {
|
|
167
|
-
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
|
|
211
|
+
theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
|
|
212
|
+
} else if (sector >= sections.v[0] && sector < sec_w) {
|
|
213
|
+
theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
|
|
214
|
+
} else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
|
215
|
+
theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
|
|
216
|
+
} else if (sector >= sec_w + sections.v[2]) {
|
|
217
|
+
theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
|
|
168
218
|
}
|
|
169
219
|
}
|
|
170
220
|
|
|
171
221
|
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
|
172
|
-
float cos_theta;
|
|
173
|
-
float sin_theta;
|
|
174
|
-
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
|
175
|
-
const float x0 = x[ix + 0];
|
|
176
|
-
const float x1 = x[ix + n_dims/2];
|
|
177
222
|
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta;
|
|
181
|
-
}
|
|
223
|
+
float cos_theta;
|
|
224
|
+
float sin_theta;
|
|
182
225
|
|
|
226
|
+
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
|
|
227
|
+
ext_factor, attn_factor, cos_theta, sin_theta);
|
|
183
228
|
|
|
229
|
+
const float x0 = x[ix + 0];
|
|
230
|
+
const float x1 = x[ix + n_dims / 2];
|
|
231
|
+
|
|
232
|
+
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
|
|
233
|
+
dst[idst + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
|
|
234
|
+
}
|
|
184
235
|
|
|
185
|
-
template <
|
|
186
|
-
static void rope_vision(const T *
|
|
187
|
-
const
|
|
188
|
-
const
|
|
189
|
-
const
|
|
190
|
-
const
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
236
|
+
template <bool forward, bool has_ff, typename T>
|
|
237
|
+
static void rope_vision(const T *x, T *dst, const int ne00, const int ne01,
|
|
238
|
+
const int ne02, const int s01, const int s02,
|
|
239
|
+
const int s03, const int s1, const int s2, const int s3,
|
|
240
|
+
const int n_dims, const int32_t *pos,
|
|
241
|
+
const float freq_scale, const float ext_factor,
|
|
242
|
+
const float attn_factor, const rope_corr_dims corr_dims,
|
|
243
|
+
const float theta_scale, const float *freq_factors,
|
|
244
|
+
const mrope_sections sections) {
|
|
245
|
+
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
|
246
|
+
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
|
247
|
+
item_ct1.get_local_id(1));
|
|
248
|
+
|
|
249
|
+
if (i0 >= ne00) {
|
|
194
250
|
return;
|
|
195
251
|
}
|
|
196
|
-
|
|
197
|
-
const int
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
const
|
|
252
|
+
|
|
253
|
+
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
254
|
+
item_ct1.get_local_id(2);
|
|
255
|
+
|
|
256
|
+
const uint32_t i3 = row_dst / (ne01 * ne02);
|
|
257
|
+
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
|
|
258
|
+
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
|
|
259
|
+
|
|
260
|
+
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
|
|
261
|
+
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
|
|
201
262
|
|
|
202
263
|
const int sect_dims = sections.v[0] + sections.v[1];
|
|
203
|
-
const int
|
|
264
|
+
const int sec_w = sections.v[1] + sections.v[0];
|
|
265
|
+
const int sector = (i0 / 2) % sect_dims;
|
|
204
266
|
|
|
205
|
-
float theta_base = 0.
|
|
267
|
+
float theta_base = 0.0;
|
|
206
268
|
if (sector < sections.v[0]) {
|
|
207
269
|
const int p = sector;
|
|
208
|
-
theta_base
|
|
209
|
-
} else {
|
|
210
|
-
// Simplified from CUDA backend code: if (sector >= sections.v[0] && sector < sec_w) which is just sector >= sections.v[0]
|
|
270
|
+
theta_base = pos[i2] * dpct::pow(theta_scale, p);
|
|
271
|
+
} else if (sector >= sections.v[0] && sector < sec_w) {
|
|
211
272
|
const int p = sector - sections.v[0];
|
|
212
|
-
theta_base
|
|
273
|
+
theta_base = pos[i2 + ne02] * dpct::pow(theta_scale, p);
|
|
213
274
|
}
|
|
214
275
|
|
|
215
276
|
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
|
216
|
-
|
|
217
|
-
float
|
|
218
|
-
|
|
277
|
+
|
|
278
|
+
float cos_theta;
|
|
279
|
+
float sin_theta;
|
|
280
|
+
|
|
281
|
+
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
|
|
282
|
+
ext_factor, attn_factor, cos_theta, sin_theta);
|
|
283
|
+
|
|
219
284
|
const float x0 = x[ix + 0];
|
|
220
285
|
const float x1 = x[ix + n_dims];
|
|
221
286
|
|
|
222
|
-
|
|
223
|
-
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
|
|
287
|
+
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
|
|
224
288
|
dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta;
|
|
225
289
|
}
|
|
226
290
|
|
|
227
|
-
template <typename T>
|
|
228
|
-
static void
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
291
|
+
template <bool forward, typename T, typename D>
|
|
292
|
+
static void
|
|
293
|
+
rope_norm_sycl(const T *x, D *dst, const int ne00, const int ne01,
|
|
294
|
+
const int ne02, const int s01, const int s02, const int s03,
|
|
295
|
+
const int s1, const int s2, const int s3, const int n_dims,
|
|
296
|
+
const int nr, const int32_t *pos, const float freq_scale,
|
|
297
|
+
const float freq_base, const float ext_factor,
|
|
298
|
+
const float attn_factor, const rope_corr_dims corr_dims,
|
|
299
|
+
const float *freq_factors, const int64_t *row_indices,
|
|
300
|
+
const int set_rows_stride, dpct::queue_ptr stream) {
|
|
301
|
+
GGML_ASSERT(ne00 % 2 == 0);
|
|
302
|
+
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
|
303
|
+
const int n_blocks_x =
|
|
304
|
+
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
|
|
305
|
+
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
|
|
236
306
|
|
|
237
307
|
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
|
238
308
|
|
|
239
|
-
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
|
240
|
-
|
|
241
309
|
if (freq_factors == nullptr) {
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
310
|
+
stream->parallel_for(
|
|
311
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
312
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
313
|
+
GGML_UNUSED(item_ct1);
|
|
314
|
+
rope_norm<forward, false>(
|
|
315
|
+
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
|
316
|
+
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
|
317
|
+
theta_scale, freq_factors, row_indices, set_rows_stride);
|
|
318
|
+
});
|
|
251
319
|
} else {
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
320
|
+
stream->parallel_for(
|
|
321
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
322
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
323
|
+
GGML_UNUSED(item_ct1);
|
|
324
|
+
rope_norm<forward, true>(
|
|
325
|
+
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
|
326
|
+
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
|
327
|
+
theta_scale, freq_factors, row_indices, set_rows_stride);
|
|
328
|
+
});
|
|
261
329
|
}
|
|
262
330
|
}
|
|
263
331
|
|
|
264
|
-
template <typename T>
|
|
265
|
-
static void
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
332
|
+
template <bool forward, typename T, typename D>
|
|
333
|
+
static void
|
|
334
|
+
rope_neox_sycl(const T *x, D *dst, const int ne00, const int ne01,
|
|
335
|
+
const int ne02, const int s01, const int s02, const int s03,
|
|
336
|
+
const int s1, const int s2, const int s3, const int n_dims,
|
|
337
|
+
const int nr, const int32_t *pos, const float freq_scale,
|
|
338
|
+
const float freq_base, const float ext_factor,
|
|
339
|
+
const float attn_factor, const rope_corr_dims corr_dims,
|
|
340
|
+
const float *freq_factors, const int64_t *row_indices,
|
|
341
|
+
const int set_rows_stride, dpct::queue_ptr stream) {
|
|
342
|
+
GGML_ASSERT(ne00 % 2 == 0);
|
|
343
|
+
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
|
344
|
+
const int n_blocks_x =
|
|
345
|
+
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
|
|
346
|
+
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
|
|
273
347
|
|
|
274
348
|
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
|
275
349
|
|
|
276
|
-
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
|
277
|
-
|
|
278
350
|
if (freq_factors == nullptr) {
|
|
279
|
-
stream->parallel_for(
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
351
|
+
stream->parallel_for(
|
|
352
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
353
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
354
|
+
GGML_UNUSED(item_ct1);
|
|
355
|
+
rope_neox<forward, false>(
|
|
356
|
+
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
|
357
|
+
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
|
358
|
+
theta_scale, freq_factors, row_indices, set_rows_stride);
|
|
359
|
+
});
|
|
283
360
|
} else {
|
|
284
|
-
stream->parallel_for(
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
361
|
+
stream->parallel_for(
|
|
362
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
363
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
364
|
+
GGML_UNUSED(item_ct1);
|
|
365
|
+
rope_neox<forward, true>(
|
|
366
|
+
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
|
367
|
+
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
|
368
|
+
theta_scale, freq_factors, row_indices, set_rows_stride);
|
|
369
|
+
});
|
|
288
370
|
}
|
|
289
371
|
}
|
|
290
372
|
|
|
291
|
-
template <typename T>
|
|
292
|
-
static void
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
const
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
373
|
+
template <bool forward, typename T>
|
|
374
|
+
static void
|
|
375
|
+
rope_multi_sycl(const T *x, T *dst, const int ne00, const int ne01,
|
|
376
|
+
const int ne02, const int s01, const int s02, const int s03,
|
|
377
|
+
const int s1, const int s2, const int s3, const int n_dims,
|
|
378
|
+
const int nr, const int32_t *pos, const float freq_scale,
|
|
379
|
+
const float freq_base, const float ext_factor,
|
|
380
|
+
const float attn_factor, const rope_corr_dims corr_dims,
|
|
381
|
+
const float *freq_factors, const mrope_sections sections,
|
|
382
|
+
const bool is_imrope, dpct::queue_ptr stream) {
|
|
383
|
+
GGML_ASSERT(ne00 % 2 == 0);
|
|
384
|
+
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
|
385
|
+
const int n_blocks_x =
|
|
386
|
+
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
|
|
387
|
+
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
|
|
388
|
+
|
|
389
|
+
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
|
390
|
+
|
|
309
391
|
if (freq_factors == nullptr) {
|
|
310
|
-
stream->parallel_for(
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
392
|
+
stream->parallel_for(
|
|
393
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
394
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
395
|
+
GGML_UNUSED(item_ct1);
|
|
396
|
+
rope_multi<forward, false, T>(
|
|
397
|
+
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
|
398
|
+
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
|
399
|
+
theta_scale, freq_factors, sections, is_imrope);
|
|
400
|
+
});
|
|
314
401
|
} else {
|
|
315
|
-
stream->parallel_for(
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
402
|
+
stream->parallel_for(
|
|
403
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
404
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
405
|
+
GGML_UNUSED(item_ct1);
|
|
406
|
+
rope_multi<forward, true, T>(
|
|
407
|
+
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
|
408
|
+
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
|
409
|
+
theta_scale, freq_factors, sections, is_imrope);
|
|
410
|
+
});
|
|
319
411
|
}
|
|
320
412
|
}
|
|
321
413
|
|
|
414
|
+
template <bool forward, typename T>
|
|
415
|
+
static void
|
|
416
|
+
rope_vision_sycl(const T *x, T *dst, const int ne00, const int ne01,
|
|
417
|
+
const int ne02, const int s01, const int s02, const int s03,
|
|
418
|
+
const int s1, const int s2, const int s3, const int n_dims,
|
|
419
|
+
const int nr, const int32_t *pos, const float freq_scale,
|
|
420
|
+
const float freq_base, const float ext_factor,
|
|
421
|
+
const float attn_factor, const rope_corr_dims corr_dims,
|
|
422
|
+
const float *freq_factors, const mrope_sections sections,
|
|
423
|
+
dpct::queue_ptr stream) {
|
|
424
|
+
GGML_ASSERT(ne00 % 2 == 0);
|
|
425
|
+
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
|
426
|
+
const int n_blocks_x =
|
|
427
|
+
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
|
|
428
|
+
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
|
|
322
429
|
|
|
430
|
+
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
|
323
431
|
|
|
324
|
-
|
|
325
|
-
// rope vision
|
|
326
|
-
template <typename T>
|
|
327
|
-
static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
|
328
|
-
const size_t s2, const int n_dims, const int nr, const int32_t * pos,
|
|
329
|
-
const float freq_scale, const float freq_base, const float ext_factor,
|
|
330
|
-
const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
|
|
331
|
-
const mrope_sections sections, queue_ptr stream) {
|
|
332
|
-
GGML_ASSERT(ne0 % 2 == 0);
|
|
333
|
-
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
|
334
|
-
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
|
335
|
-
const sycl::range<3> grid_dims(1, n_blocks_y, nr);
|
|
336
|
-
const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
|
|
337
|
-
|
|
338
|
-
const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
|
|
339
|
-
// Add FP16 capability check if T could be sycl::half
|
|
340
|
-
if constexpr (std::is_same_v<T, sycl::half>) {
|
|
341
|
-
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
|
342
|
-
}
|
|
343
|
-
// launch kernel
|
|
344
432
|
if (freq_factors == nullptr) {
|
|
345
|
-
stream->parallel_for(
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
433
|
+
stream->parallel_for(
|
|
434
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
435
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
436
|
+
GGML_UNUSED(item_ct1);
|
|
437
|
+
rope_vision<forward, false, T>(
|
|
438
|
+
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
|
439
|
+
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
|
440
|
+
theta_scale, freq_factors, sections);
|
|
441
|
+
});
|
|
349
442
|
} else {
|
|
350
|
-
stream->parallel_for(
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
443
|
+
stream->parallel_for(
|
|
444
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
445
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
446
|
+
GGML_UNUSED(item_ct1);
|
|
447
|
+
rope_vision<forward, true, T>(
|
|
448
|
+
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
|
449
|
+
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
|
450
|
+
theta_scale, freq_factors, sections);
|
|
451
|
+
});
|
|
354
452
|
}
|
|
355
453
|
}
|
|
356
454
|
|
|
357
|
-
|
|
455
|
+
template <bool forward>
|
|
456
|
+
void ggml_sycl_op_rope_impl(ggml_backend_sycl_context &ctx, ggml_tensor *dst,
|
|
457
|
+
const ggml_tensor *set_rows = nullptr) {
|
|
458
|
+
const ggml_tensor *src0 = dst->src[0];
|
|
459
|
+
const ggml_tensor *src1 = dst->src[1];
|
|
460
|
+
const ggml_tensor *src2 = dst->src[2];
|
|
461
|
+
|
|
462
|
+
const float *src0_d = (const float *)src0->data;
|
|
463
|
+
const float *src1_d = (const float *)src1->data;
|
|
464
|
+
|
|
465
|
+
void *dst_d = dst->data;
|
|
466
|
+
const int64_t *row_indices = nullptr;
|
|
467
|
+
ggml_type dst_type = dst->type;
|
|
468
|
+
int set_rows_stride = 0;
|
|
469
|
+
|
|
470
|
+
if (set_rows != nullptr) {
|
|
471
|
+
GGML_ASSERT(forward);
|
|
472
|
+
dst_d = set_rows->data;
|
|
473
|
+
row_indices = (const int64_t *)set_rows->src[1]->data;
|
|
474
|
+
dst_type = set_rows->type;
|
|
475
|
+
set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
|
|
476
|
+
}
|
|
477
|
+
dpct::queue_ptr stream = ctx.stream();
|
|
478
|
+
|
|
479
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
480
|
+
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
481
|
+
GGML_ASSERT(src0->type == dst->type ||
|
|
482
|
+
(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16));
|
|
358
483
|
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
const int64_t
|
|
363
|
-
const int64_t ne01 = dst->src[0]->ne[1]; // num heads
|
|
364
|
-
const int64_t ne02 = dst->src[0]->ne[2]; // num heads
|
|
365
|
-
const int64_t nr = ggml_nrows(dst->src[0]);
|
|
484
|
+
const int64_t ne00 = src0->ne[0]; // head dims
|
|
485
|
+
const int64_t ne01 = src0->ne[1]; // num heads
|
|
486
|
+
const int64_t ne02 = src0->ne[2]; // num heads
|
|
487
|
+
const int64_t nr = ggml_nrows(src0);
|
|
366
488
|
|
|
367
|
-
const size_t s01 =
|
|
368
|
-
const size_t s02 =
|
|
489
|
+
const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
|
|
490
|
+
const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
|
|
491
|
+
const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);
|
|
369
492
|
|
|
493
|
+
const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);
|
|
494
|
+
const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);
|
|
495
|
+
const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);
|
|
370
496
|
|
|
371
|
-
|
|
372
|
-
const int
|
|
373
|
-
const int
|
|
374
|
-
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
|
375
|
-
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
|
497
|
+
const int n_dims = ((int32_t *)dst->op_params)[1];
|
|
498
|
+
const int mode = ((int32_t *)dst->op_params)[2];
|
|
499
|
+
const int n_ctx_orig = ((int32_t *)dst->op_params)[4];
|
|
376
500
|
mrope_sections sections;
|
|
377
501
|
|
|
378
|
-
// RoPE alteration for extended context
|
|
379
502
|
float freq_base;
|
|
380
503
|
float freq_scale;
|
|
381
504
|
float ext_factor;
|
|
@@ -383,13 +506,13 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
|
|
383
506
|
float beta_fast;
|
|
384
507
|
float beta_slow;
|
|
385
508
|
|
|
386
|
-
memcpy(&freq_base,
|
|
387
|
-
memcpy(&freq_scale,
|
|
388
|
-
memcpy(&ext_factor,
|
|
389
|
-
memcpy(&attn_factor, (int32_t *)
|
|
390
|
-
memcpy(&beta_fast,
|
|
391
|
-
memcpy(&beta_slow,
|
|
392
|
-
memcpy(§ions.v,
|
|
509
|
+
memcpy(&freq_base, (int32_t *)dst->op_params + 5, sizeof(float));
|
|
510
|
+
memcpy(&freq_scale, (int32_t *)dst->op_params + 6, sizeof(float));
|
|
511
|
+
memcpy(&ext_factor, (int32_t *)dst->op_params + 7, sizeof(float));
|
|
512
|
+
memcpy(&attn_factor, (int32_t *)dst->op_params + 8, sizeof(float));
|
|
513
|
+
memcpy(&beta_fast, (int32_t *)dst->op_params + 9, sizeof(float));
|
|
514
|
+
memcpy(&beta_slow, (int32_t *)dst->op_params + 10, sizeof(float));
|
|
515
|
+
memcpy(§ions.v, (int32_t *)dst->op_params + 11, sizeof(int) * 4);
|
|
393
516
|
|
|
394
517
|
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
|
395
518
|
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
|
@@ -397,82 +520,122 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
|
|
397
520
|
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
|
398
521
|
|
|
399
522
|
if (is_mrope) {
|
|
400
|
-
GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 ||
|
|
523
|
+
GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 ||
|
|
524
|
+
sections.v[2] > 0);
|
|
401
525
|
}
|
|
402
526
|
|
|
403
527
|
if (is_vision) {
|
|
404
|
-
GGML_ASSERT(n_dims == ne00/2);
|
|
528
|
+
GGML_ASSERT(n_dims == ne00 / 2);
|
|
405
529
|
}
|
|
406
530
|
|
|
407
|
-
const int32_t *
|
|
531
|
+
const int32_t *pos = (const int32_t *)src1_d;
|
|
408
532
|
|
|
409
|
-
const float *
|
|
410
|
-
if (
|
|
411
|
-
freq_factors = (const float *)
|
|
533
|
+
const float *freq_factors = nullptr;
|
|
534
|
+
if (src2 != nullptr) {
|
|
535
|
+
freq_factors = (const float *)src2->data;
|
|
412
536
|
}
|
|
413
537
|
|
|
414
538
|
rope_corr_dims corr_dims;
|
|
415
|
-
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast,
|
|
416
|
-
|
|
417
|
-
dpct::queue_ptr main_stream = ctx.stream();
|
|
418
|
-
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
|
539
|
+
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast,
|
|
540
|
+
beta_slow, corr_dims.v);
|
|
419
541
|
|
|
420
542
|
// compute
|
|
421
543
|
if (is_neox) {
|
|
422
544
|
GGML_SYCL_DEBUG("%s: neox path\n", __func__);
|
|
423
|
-
if (
|
|
424
|
-
rope_neox_sycl
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
545
|
+
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
|
|
546
|
+
rope_neox_sycl<forward, float, float>(
|
|
547
|
+
(const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
|
|
548
|
+
s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
|
|
549
|
+
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
|
|
550
|
+
set_rows_stride, stream);
|
|
551
|
+
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
|
|
552
|
+
rope_neox_sycl<forward, float, sycl::half>(
|
|
553
|
+
(const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
|
|
554
|
+
s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
|
|
555
|
+
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
|
556
|
+
row_indices, set_rows_stride, stream);
|
|
557
|
+
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
|
|
558
|
+
rope_neox_sycl<forward, sycl::half, sycl::half>(
|
|
559
|
+
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
|
|
560
|
+
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
|
|
561
|
+
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
|
562
|
+
row_indices, set_rows_stride, stream);
|
|
430
563
|
} else {
|
|
431
|
-
GGML_ABORT("
|
|
564
|
+
GGML_ABORT("Fatal error: Tensor type unsupported!");
|
|
432
565
|
}
|
|
433
566
|
} else if (is_mrope && !is_vision) {
|
|
434
567
|
GGML_SYCL_DEBUG("%s: mrope path\n", __func__);
|
|
435
|
-
if (
|
|
436
|
-
rope_multi_sycl((const
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
568
|
+
if (src0->type == GGML_TYPE_F32) {
|
|
569
|
+
rope_multi_sycl<forward>((const float *)src0_d, (float *)dst_d,
|
|
570
|
+
ne00, ne01, ne02, s01, s02, s03, s1, s2,
|
|
571
|
+
s3, n_dims, nr, pos, freq_scale, freq_base,
|
|
572
|
+
ext_factor, attn_factor, corr_dims,
|
|
573
|
+
freq_factors, sections, is_imrope, stream);
|
|
574
|
+
} else if (src0->type == GGML_TYPE_F16) {
|
|
575
|
+
rope_multi_sycl<forward>(
|
|
576
|
+
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
|
|
577
|
+
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
|
|
578
|
+
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
|
579
|
+
sections, is_imrope, stream);
|
|
443
580
|
} else {
|
|
444
581
|
GGML_ABORT("Fatal error: Tensor type unsupported!");
|
|
445
582
|
}
|
|
446
583
|
} else if (is_vision) {
|
|
447
584
|
GGML_SYCL_DEBUG("%s: vision path\n", __func__);
|
|
448
|
-
if (
|
|
449
|
-
rope_vision_sycl(
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
585
|
+
if (src0->type == GGML_TYPE_F32) {
|
|
586
|
+
rope_vision_sycl<forward>(
|
|
587
|
+
(const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
|
|
588
|
+
s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
|
|
589
|
+
ext_factor, attn_factor, corr_dims, freq_factors, sections,
|
|
590
|
+
stream);
|
|
591
|
+
} else if (src0->type == GGML_TYPE_F16) {
|
|
592
|
+
rope_vision_sycl<forward>(
|
|
593
|
+
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
|
|
594
|
+
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
|
|
595
|
+
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
|
596
|
+
sections, stream);
|
|
456
597
|
} else {
|
|
457
598
|
GGML_ABORT("Fatal error: Tensor type unsupported!");
|
|
458
599
|
}
|
|
459
600
|
} else {
|
|
460
601
|
GGML_SYCL_DEBUG("%s: norm path\n", __func__);
|
|
461
|
-
if (
|
|
462
|
-
rope_norm_sycl
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
602
|
+
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
|
|
603
|
+
rope_norm_sycl<forward, float, float>(
|
|
604
|
+
(const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
|
|
605
|
+
s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
|
|
606
|
+
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
|
|
607
|
+
set_rows_stride, stream);
|
|
608
|
+
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
|
|
609
|
+
rope_norm_sycl<forward, float, sycl::half>(
|
|
610
|
+
(const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
|
|
611
|
+
s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
|
|
612
|
+
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
|
613
|
+
row_indices, set_rows_stride, stream);
|
|
614
|
+
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
|
|
615
|
+
rope_norm_sycl<forward, sycl::half, sycl::half>(
|
|
616
|
+
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
|
|
617
|
+
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
|
|
618
|
+
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
|
619
|
+
row_indices, set_rows_stride, stream);
|
|
468
620
|
} else {
|
|
469
|
-
GGML_ABORT("
|
|
621
|
+
GGML_ABORT("Fatal error: Tensor type unsupported!");
|
|
470
622
|
}
|
|
471
623
|
}
|
|
472
624
|
}
|
|
473
625
|
|
|
474
|
-
void ggml_sycl_rope(ggml_backend_sycl_context &
|
|
626
|
+
void ggml_sycl_rope(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
|
|
475
627
|
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
|
|
476
|
-
|
|
628
|
+
|
|
629
|
+
ggml_sycl_op_rope_impl<true>(ctx, dst);
|
|
477
630
|
}
|
|
478
631
|
|
|
632
|
+
void ggml_sycl_rope_back(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
|
|
633
|
+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
|
|
634
|
+
ggml_sycl_op_rope_impl<false>(ctx, dst);
|
|
635
|
+
}
|
|
636
|
+
|
|
637
|
+
void ggml_sycl_rope_fused(ggml_backend_sycl_context &ctx, ggml_tensor *rope,
|
|
638
|
+
ggml_tensor *set_rows) {
|
|
639
|
+
scope_op_debug_print scope_dbg_print(__func__, rope, /*num_src=*/3);
|
|
640
|
+
ggml_sycl_op_rope_impl<true>(ctx, rope, set_rows);
|
|
641
|
+
}
|