whispercpp 1.3.5 → 1.3.6
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +99 -2
- data/ext/extconf.rb +1 -0
- data/ext/ruby_whisper.c +20 -4
- data/ext/ruby_whisper.h +30 -2
- data/ext/ruby_whisper_context.c +216 -124
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_params.c +0 -1
- data/ext/ruby_whisper_segment.c +0 -1
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +4 -1
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +8 -0
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/server/server.cpp +18 -4
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
- data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
- data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
- data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
- data/ext/sources/examples/talk-llama/llama-context.h +27 -28
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
- data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
- data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
- data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
- data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
- data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
- data/ext/sources/examples/talk-llama/llama-model.h +72 -19
- data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
- data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
- data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -22
- data/ext/sources/examples/talk-llama/llama.h +63 -30
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
- data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
- data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
- data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
- data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
- data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/models.h +181 -46
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
- data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
- data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
- data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
- data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
- data/ext/sources/ggml/CMakeLists.txt +9 -3
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +6 -1
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +56 -9
- data/ext/sources/ggml/src/CMakeLists.txt +3 -0
- data/ext/sources/ggml/src/ggml-alloc.c +4 -9
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
- data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
- data/ext/sources/ggml/src/ggml-impl.h +62 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
- data/ext/sources/ggml/src/ggml.c +167 -33
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/src/whisper.cpp +6 -28
- data/sig/whisper.rbs +43 -2
- data/test/test_context_params.rb +82 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_whisper.rb +20 -0
- data/whispercpp.gemspec +1 -1
- metadata +240 -28
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
|
@@ -0,0 +1,1179 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include <sycl/sycl.hpp>
|
|
4
|
+
#include "dpct/helper.hpp"
|
|
5
|
+
#include "common.hpp"
|
|
6
|
+
#include "convert.hpp"
|
|
7
|
+
#include "vecdotq.hpp"
|
|
8
|
+
|
|
9
|
+
#include "ggml.h"
|
|
10
|
+
|
|
11
|
+
#include <cstdint>
|
|
12
|
+
#include <cmath>
|
|
13
|
+
#include <float.h>
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
#define FATTN_KQ_STRIDE 256 // NOTE(review): presumably a KV-sequence stride/granularity for the flash-attention kernels; its use is not visible here - confirm.
|
|
17
|
+
#define HALF_MAX_HALF sycl::half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
|
|
18
|
+
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
|
|
19
|
+
#define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f) // 0.6931f approximates ln(2), so this is about 3*ln(2), i.e. exp(offset) is about 8.
|
|
20
|
+
|
|
21
|
+
typedef void (*fattn_kernel_t)(
|
|
22
|
+
const char* Q,
|
|
23
|
+
const char* K,
|
|
24
|
+
const char* V,
|
|
25
|
+
const char* mask,
|
|
26
|
+
const char* sinks,
|
|
27
|
+
const int* KV_max,
|
|
28
|
+
float* dst,
|
|
29
|
+
sycl::float2* dst_meta,
|
|
30
|
+
const float scale,
|
|
31
|
+
const float max_bias,
|
|
32
|
+
const float m0,
|
|
33
|
+
const float m1,
|
|
34
|
+
const uint32_t n_head_log2,
|
|
35
|
+
const float logit_softcap,
|
|
36
|
+
const int32_t ne00,
|
|
37
|
+
const sycl::uint3 ne01,
|
|
38
|
+
const int32_t ne02,
|
|
39
|
+
const int32_t ne03,
|
|
40
|
+
const int32_t nb01,
|
|
41
|
+
const int32_t nb02,
|
|
42
|
+
const int32_t nb03,
|
|
43
|
+
const int32_t ne10,
|
|
44
|
+
const int32_t ne11,
|
|
45
|
+
const int32_t ne12,
|
|
46
|
+
const int32_t ne13,
|
|
47
|
+
const int32_t nb11,
|
|
48
|
+
const int32_t nb12,
|
|
49
|
+
const int64_t nb13,
|
|
50
|
+
const int32_t nb21,
|
|
51
|
+
const int32_t nb22,
|
|
52
|
+
const int64_t nb23,
|
|
53
|
+
const int32_t ne31,
|
|
54
|
+
const int32_t ne32,
|
|
55
|
+
const int32_t ne33,
|
|
56
|
+
const int32_t nb31,
|
|
57
|
+
const int32_t nb32,
|
|
58
|
+
const int64_t nb33);
|
|
59
|
+
|
|
60
|
+
// Function-pointer type shared by the vec_dot_fattn_vec_KQ_* routines: each takes a K row plus
// the three Q representations (raw values, packed q8 ints, per-block float2 pairs) and
// returns a partial dot product.
using vec_dot_KQ_t = float (*)(const char * __restrict__ K_c,
                               const void * __restrict__ Q_v,
                               const int *  __restrict__ Q_q8,
                               const void * __restrict__ Q_ds);
|
|
62
|
+
|
|
63
|
+
template <int D, int nthreads>
|
|
64
|
+
static __dpct_inline__ float vec_dot_fattn_vec_KQ_f16(const char * __restrict__ K_c,
|
|
65
|
+
const void * __restrict__ Q_v,
|
|
66
|
+
const int * __restrict__ Q_q8,
|
|
67
|
+
const void * __restrict__ Q_ds_v) {
|
|
68
|
+
const sycl::half2 * K_h2 = (const sycl::half2 *) K_c;
|
|
69
|
+
GGML_UNUSED(Q_q8);
|
|
70
|
+
GGML_UNUSED(Q_ds_v);
|
|
71
|
+
|
|
72
|
+
constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
|
|
73
|
+
constexpr int cpy_ne = cpy_nb / 4;
|
|
74
|
+
|
|
75
|
+
float sum = 0.0f;
|
|
76
|
+
|
|
77
|
+
#pragma unroll
|
|
78
|
+
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
|
|
79
|
+
sycl::half2 tmp[cpy_ne];
|
|
80
|
+
ggml_sycl_memcpy_1<sizeof(tmp)>(
|
|
81
|
+
tmp,
|
|
82
|
+
K_h2 + k_KQ_0 + (sycl::ext::oneapi::this_work_item::get_nd_item<3>().get_local_id(2) % nthreads) * cpy_ne);
|
|
83
|
+
#pragma unroll
|
|
84
|
+
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
|
|
85
|
+
#ifdef GGML_SYCL_F16
|
|
86
|
+
ggml_sycl_mad(sum, tmp[k_KQ_1] , ((const sycl::half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
|
|
87
|
+
#else
|
|
88
|
+
ggml_sycl_mad(sum, __half22float2(tmp[k_KQ_1]), ((const sycl::float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
|
|
89
|
+
#endif // GGML_SYCL_F16
|
|
90
|
+
}
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
return sum;
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
template <int D, int nthreads, int warp_size>
|
|
97
|
+
static __dpct_inline__ float vec_dot_fattn_vec_KQ_q4_0(const char * __restrict__ K_c,
|
|
98
|
+
const void * __restrict__ Q_v,
|
|
99
|
+
const int * __restrict__ Q_q8,
|
|
100
|
+
const void * __restrict__ Q_ds_v) {
|
|
101
|
+
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
|
102
|
+
|
|
103
|
+
const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
|
|
104
|
+
GGML_UNUSED(Q_v);
|
|
105
|
+
|
|
106
|
+
float sum = 0.0f;
|
|
107
|
+
|
|
108
|
+
#pragma unroll
|
|
109
|
+
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
|
|
110
|
+
const int k_KQ =
|
|
111
|
+
k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
|
|
112
|
+
|
|
113
|
+
const int ib = k_KQ / QI8_1;
|
|
114
|
+
const int iqs4 = k_KQ % QI4_0;
|
|
115
|
+
const int shift = k_KQ & (QI8_1/2);
|
|
116
|
+
|
|
117
|
+
int v;
|
|
118
|
+
ggml_sycl_memcpy_1<sizeof(int), 2>(&v, K_q4_0[ib].qs + sizeof(int)*iqs4);
|
|
119
|
+
v = (v >> shift) & 0x0F0F0F0F;
|
|
120
|
+
const int u = Q_q8[k_KQ_0/nthreads];
|
|
121
|
+
|
|
122
|
+
const int sumi = ggml_sycl_dp4a(v, u, 0);
|
|
123
|
+
|
|
124
|
+
const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];
|
|
125
|
+
sum += __half2float(K_q4_0[ib].d) * (sumi*Q_ds.x() - (8/QI8_1)*Q_ds.y());
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
return sum;
|
|
129
|
+
}
|
|
130
|
+
|
|
131
|
+
template <int D, int nthreads , int warp_size>
|
|
132
|
+
static __dpct_inline__ float vec_dot_fattn_vec_KQ_q4_1(const char * __restrict__ K_c,
|
|
133
|
+
const void * __restrict__ Q_v,
|
|
134
|
+
const int * __restrict__ Q_q8,
|
|
135
|
+
const void * __restrict__ Q_ds_v) {
|
|
136
|
+
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
|
137
|
+
const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
|
|
138
|
+
GGML_UNUSED(Q_v);
|
|
139
|
+
|
|
140
|
+
float sum = 0.0f;
|
|
141
|
+
|
|
142
|
+
#pragma unroll
|
|
143
|
+
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
|
|
144
|
+
const int k_KQ =
|
|
145
|
+
k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
|
|
146
|
+
|
|
147
|
+
const int ib = k_KQ / QI8_1;
|
|
148
|
+
const int iqs4 = k_KQ % QI4_1;
|
|
149
|
+
const int shift = k_KQ & (QI8_1/2);
|
|
150
|
+
|
|
151
|
+
int v;
|
|
152
|
+
ggml_sycl_memcpy_1<sizeof(int)>(&v, K_q4_1[ib].qs + sizeof(int)*iqs4);
|
|
153
|
+
v = (v >> shift) & 0x0F0F0F0F;
|
|
154
|
+
const int u = Q_q8[k_KQ_0/nthreads];
|
|
155
|
+
|
|
156
|
+
const int sumi = ggml_sycl_dp4a(v, u, 0);
|
|
157
|
+
|
|
158
|
+
const sycl::float2 K_dm = (K_q4_1[ib].dm).template convert<float, sycl::rounding_mode::automatic>();
|
|
159
|
+
const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];
|
|
160
|
+
|
|
161
|
+
sum += K_dm.x()*Q_ds.x()*sumi + K_dm.y()*Q_ds.y()/QI8_1;
|
|
162
|
+
}
|
|
163
|
+
|
|
164
|
+
return sum;
|
|
165
|
+
}
|
|
166
|
+
|
|
167
|
+
template <int D, int nthreads, int warp_size>
|
|
168
|
+
static __dpct_inline__ float vec_dot_fattn_vec_KQ_q5_0(const char * __restrict__ K_c,
|
|
169
|
+
const void * __restrict__ Q_v,
|
|
170
|
+
const int * __restrict__ Q_q8,
|
|
171
|
+
const void * __restrict__ Q_ds_v) {
|
|
172
|
+
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
|
173
|
+
const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
|
|
174
|
+
GGML_UNUSED(Q_v);
|
|
175
|
+
|
|
176
|
+
float sum = 0.0f;
|
|
177
|
+
|
|
178
|
+
#pragma unroll
|
|
179
|
+
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
|
|
180
|
+
const int k_KQ =
|
|
181
|
+
k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
|
|
182
|
+
|
|
183
|
+
const int ib = k_KQ / QI8_1;
|
|
184
|
+
const int iqs4 = k_KQ % QI5_0;
|
|
185
|
+
const int iqs8 = k_KQ % QI8_1;
|
|
186
|
+
const int shift = k_KQ & (QI8_1/2);
|
|
187
|
+
|
|
188
|
+
int v;
|
|
189
|
+
ggml_sycl_memcpy_1<sizeof(int), 2>(&v, K_q5_0[ib].qs + sizeof(int)*iqs4);
|
|
190
|
+
v = (v >> shift) & 0x0F0F0F0F;
|
|
191
|
+
|
|
192
|
+
{
|
|
193
|
+
int vh;
|
|
194
|
+
ggml_sycl_memcpy_1<sizeof(int), 2>(&vh, K_q5_0[ib].qh);
|
|
195
|
+
vh >>= iqs8 * QI5_0;
|
|
196
|
+
|
|
197
|
+
v |= (vh << 4) & 0x00000010; // 0 -> 4
|
|
198
|
+
v |= (vh << 11) & 0x00001000; // 1 -> 12
|
|
199
|
+
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
|
200
|
+
v |= (vh << 25) & 0x10000000; // 3 -> 28
|
|
201
|
+
}
|
|
202
|
+
|
|
203
|
+
const int u = Q_q8[k_KQ_0/nthreads];
|
|
204
|
+
|
|
205
|
+
const int sumi = ggml_sycl_dp4a(v, u, 0);
|
|
206
|
+
|
|
207
|
+
const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];
|
|
208
|
+
|
|
209
|
+
sum += __half2float(K_q5_0[ib].d) * (sumi*Q_ds.x() - (16/QI8_1)*Q_ds.y());
|
|
210
|
+
}
|
|
211
|
+
|
|
212
|
+
return sum;
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
template <int D, int nthreads, int warp_size>
|
|
216
|
+
static __dpct_inline__ float vec_dot_fattn_vec_KQ_q5_1(const char * __restrict__ K_c,
|
|
217
|
+
const void * __restrict__ Q_v,
|
|
218
|
+
const int * __restrict__ Q_q8,
|
|
219
|
+
const void * __restrict__ Q_ds_v) {
|
|
220
|
+
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
|
221
|
+
const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
|
|
222
|
+
GGML_UNUSED(Q_v);
|
|
223
|
+
|
|
224
|
+
float sum = 0.0f;
|
|
225
|
+
|
|
226
|
+
#pragma unroll
|
|
227
|
+
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
|
|
228
|
+
const int k_KQ =
|
|
229
|
+
k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
|
|
230
|
+
|
|
231
|
+
const int ib = k_KQ / QI8_1;
|
|
232
|
+
const int iqs4 = k_KQ % QI5_1;
|
|
233
|
+
const int iqs8 = k_KQ % QI8_1;
|
|
234
|
+
const int shift = k_KQ & (QI8_1/2);
|
|
235
|
+
|
|
236
|
+
int v;
|
|
237
|
+
ggml_sycl_memcpy_1<sizeof(int)>(&v, K_q5_1[ib].qs + sizeof(int)*iqs4);
|
|
238
|
+
v = (v >> shift) & 0x0F0F0F0F;
|
|
239
|
+
|
|
240
|
+
{
|
|
241
|
+
int vh;
|
|
242
|
+
ggml_sycl_memcpy_1<sizeof(int)>(&vh, K_q5_1[ib].qh);
|
|
243
|
+
vh >>= iqs8 * QI5_0;
|
|
244
|
+
|
|
245
|
+
v |= (vh << 4) & 0x00000010; // 0 -> 4
|
|
246
|
+
v |= (vh << 11) & 0x00001000; // 1 -> 12
|
|
247
|
+
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
|
248
|
+
v |= (vh << 25) & 0x10000000; // 3 -> 28
|
|
249
|
+
}
|
|
250
|
+
|
|
251
|
+
const int u = Q_q8[k_KQ_0/nthreads];
|
|
252
|
+
|
|
253
|
+
const int sumi = ggml_sycl_dp4a(v, u, 0);
|
|
254
|
+
|
|
255
|
+
const sycl::float2 K_dm = (K_q5_1[ib].dm).template convert<float, sycl::rounding_mode::automatic>();
|
|
256
|
+
const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];
|
|
257
|
+
|
|
258
|
+
sum += K_dm.x()*Q_ds.x()*sumi + K_dm.y()*Q_ds.y()/QI8_1;
|
|
259
|
+
}
|
|
260
|
+
|
|
261
|
+
return sum;
|
|
262
|
+
}
|
|
263
|
+
|
|
264
|
+
template <int D, int nthreads, int warp_size>
|
|
265
|
+
static __dpct_inline__ float vec_dot_fattn_vec_KQ_q8_0(const char * __restrict__ K_c,
|
|
266
|
+
const void * __restrict__ Q_v,
|
|
267
|
+
const int * __restrict__ Q_q8,
|
|
268
|
+
const void * __restrict__ Q_ds_v) {
|
|
269
|
+
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
|
270
|
+
const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
|
|
271
|
+
GGML_UNUSED(Q_v);
|
|
272
|
+
|
|
273
|
+
float sum = 0.0f;
|
|
274
|
+
|
|
275
|
+
#pragma unroll
|
|
276
|
+
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
|
|
277
|
+
const int k_KQ =
|
|
278
|
+
k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
|
|
279
|
+
|
|
280
|
+
const int ib = k_KQ / QI8_0;
|
|
281
|
+
const int iqs = k_KQ % QI8_0;
|
|
282
|
+
|
|
283
|
+
int v;
|
|
284
|
+
ggml_sycl_memcpy_1<sizeof(v), 2>(&v, K_q8_0[ib].qs + 4*iqs);
|
|
285
|
+
|
|
286
|
+
const sycl::float2 * Q_ds = (const sycl::float2 *) Q_ds_v;
|
|
287
|
+
const float Q_d = Q_ds[k_KQ_0 / nthreads].x();
|
|
288
|
+
|
|
289
|
+
sum += vec_dot_q8_0_q8_1_impl<float, 1>(&v, &Q_q8[k_KQ_0/nthreads], K_q8_0[ib].d, Q_d);
|
|
290
|
+
}
|
|
291
|
+
|
|
292
|
+
return sum;
|
|
293
|
+
}
|
|
294
|
+
|
|
295
|
+
template <typename Tds, int ni, int warp_size>
|
|
296
|
+
static __dpct_inline__ void quantize_q8_1_to_shared(const float * __restrict__ x,
|
|
297
|
+
const float scale,
|
|
298
|
+
int * __restrict__ yq32,
|
|
299
|
+
void * __restrict__ yds) {
|
|
300
|
+
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
|
301
|
+
|
|
302
|
+
float vals[sizeof(int)] = { 0.0f };
|
|
303
|
+
#pragma unroll
|
|
304
|
+
for (int l = 0; l < int(sizeof(int)); ++l) {
|
|
305
|
+
vals[l] =
|
|
306
|
+
(ni == warp_size || item_ct1.get_local_id(2) < ni) ? scale * x[4 * item_ct1.get_local_id(2) + l] : 0.0f;
|
|
307
|
+
}
|
|
308
|
+
|
|
309
|
+
float amax = sycl::fabs(vals[0]);
|
|
310
|
+
float sum = vals[0];
|
|
311
|
+
#pragma unroll
|
|
312
|
+
for (int l = 1; l < int(sizeof(int)); ++l) {
|
|
313
|
+
amax = sycl::fmax(amax, sycl::fabs(vals[l]));
|
|
314
|
+
sum += vals[l];
|
|
315
|
+
}
|
|
316
|
+
#pragma unroll
|
|
317
|
+
for (int mask = QI8_1/2; mask > 0; mask >>= 1) {
|
|
318
|
+
amax = sycl::fmax(
|
|
319
|
+
amax, dpct::permute_sub_group_by_xor(sycl::ext::oneapi::this_work_item::get_sub_group(), amax, mask));
|
|
320
|
+
sum += dpct::permute_sub_group_by_xor(sycl::ext::oneapi::this_work_item::get_sub_group(), sum, mask);
|
|
321
|
+
}
|
|
322
|
+
|
|
323
|
+
const float d = amax / 127;
|
|
324
|
+
int q32 = 0;
|
|
325
|
+
int8_t * q8 = (int8_t *) &q32;
|
|
326
|
+
|
|
327
|
+
if (d != 0.0f) {
|
|
328
|
+
#pragma unroll
|
|
329
|
+
for (int l = 0; l < int(sizeof(int)); ++l) {
|
|
330
|
+
q8[l] = sycl::round(vals[l] / d);
|
|
331
|
+
}
|
|
332
|
+
}
|
|
333
|
+
|
|
334
|
+
yq32[item_ct1.get_local_id(2)] = q32;
|
|
335
|
+
if (item_ct1.get_local_id(2) % QI8_1 == 0 && (ni == warp_size || item_ct1.get_local_id(2) < ni)) {
|
|
336
|
+
if (std::is_same<Tds, sycl::half2>::value) {
|
|
337
|
+
((sycl::half2 *) yds)[item_ct1.get_local_id(2)/QI8_1] = make_half2(d, sum);
|
|
338
|
+
} else {
|
|
339
|
+
((sycl::float2 *) yds)[item_ct1.get_local_id(2)/QI8_1] = make_float2(d, sum);
|
|
340
|
+
}
|
|
341
|
+
}
|
|
342
|
+
}
|
|
343
|
+
|
|
344
|
+
// Function-pointer type shared by the dequantize_V_* routines: (source data, destination, element index).
using dequantize_V_t = void (*)(const void *, void *, const int64_t);
|
|
345
|
+
|
|
346
|
+
template <typename T, int ne>
|
|
347
|
+
static __dpct_inline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
|
348
|
+
if constexpr (std::is_same_v<T, sycl::half>) {
|
|
349
|
+
ggml_sycl_memcpy_1<ne * sizeof(sycl::half)>(dst, (const sycl::half *) vx + i0);
|
|
350
|
+
} else if constexpr (std::is_same_v<T, float>) {
|
|
351
|
+
static_assert(ne % 2 == 0, "bad ne");
|
|
352
|
+
sycl::half2 tmp[ne / 2];
|
|
353
|
+
ggml_sycl_memcpy_1<ne * sizeof(sycl::half)>(tmp, (const sycl::half *) vx + i0);
|
|
354
|
+
sycl::float2 * dst_f2 = (sycl::float2 *) dst;
|
|
355
|
+
#pragma unroll
|
|
356
|
+
for (int l = 0; l < ne/2; ++l) {
|
|
357
|
+
dst_f2[l] = tmp[l].template convert<float, sycl::rounding_mode::automatic>();
|
|
358
|
+
}
|
|
359
|
+
} else {
|
|
360
|
+
static_assert(std::is_same_v<T, void>, "unsupported type");
|
|
361
|
+
}
|
|
362
|
+
}
|
|
363
|
+
|
|
364
|
+
template <typename T, int ne>
|
|
365
|
+
static __dpct_inline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
|
366
|
+
const block_q4_0 * x = (const block_q4_0 *) vx;
|
|
367
|
+
|
|
368
|
+
const int64_t ib = i0 / QK4_0;
|
|
369
|
+
const int iqs = i0 % (QK4_0/2);
|
|
370
|
+
const int shift = (i0 % QK4_0) / (QK4_0/2);
|
|
371
|
+
|
|
372
|
+
int q;
|
|
373
|
+
static_assert(ne == 2 || ne == 4, "bad ne");
|
|
374
|
+
ggml_sycl_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
|
|
375
|
+
q >>= 4*shift;
|
|
376
|
+
q &= 0x0F0F0F0F;
|
|
377
|
+
q = dpct::vectorized_binary<sycl::char4>(q, 0x08080808, dpct::sub_sat());
|
|
378
|
+
|
|
379
|
+
const int8_t * q8 = (const int8_t *) &q;
|
|
380
|
+
|
|
381
|
+
#ifdef GGML_SYCL_F16
|
|
382
|
+
if constexpr (std::is_same_v<T, sycl::half>) {
|
|
383
|
+
const sycl::half2 d = sycl::half2(x[ib].d);
|
|
384
|
+
|
|
385
|
+
#pragma unroll
|
|
386
|
+
for (int l0 = 0; l0 < ne; l0 += 2) {
|
|
387
|
+
((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]);
|
|
388
|
+
}
|
|
389
|
+
} else
|
|
390
|
+
#endif // GGML_SYCL_F16
|
|
391
|
+
if constexpr (std::is_same_v<T, float>) {
|
|
392
|
+
const float d = x[ib].d;
|
|
393
|
+
|
|
394
|
+
#pragma unroll
|
|
395
|
+
for (int l = 0; l < ne; ++l) {
|
|
396
|
+
((float *) dst)[l] = d * q8[l];
|
|
397
|
+
}
|
|
398
|
+
} else {
|
|
399
|
+
static_assert(std::is_same_v<T, void>, "bad type");
|
|
400
|
+
}
|
|
401
|
+
}
|
|
402
|
+
|
|
403
|
+
template <typename T, int ne>
|
|
404
|
+
static __dpct_inline__ void dequantize_V_q4_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
|
405
|
+
const block_q4_1 * x = (const block_q4_1 *) vx;
|
|
406
|
+
|
|
407
|
+
const int64_t ib = i0 / QK4_1;
|
|
408
|
+
const int iqs = i0 % (QK4_1/2);
|
|
409
|
+
const int shift = (i0 % QK4_1) / (QK4_1/2);
|
|
410
|
+
|
|
411
|
+
int q;
|
|
412
|
+
static_assert(ne == 2 || ne == 4, "bad ne");
|
|
413
|
+
ggml_sycl_memcpy_1<ne>(&q, x[ib].qs + iqs);
|
|
414
|
+
q >>= 4*shift;
|
|
415
|
+
q &= 0x0F0F0F0F;
|
|
416
|
+
|
|
417
|
+
const int8_t * q8 = (const int8_t *) &q;
|
|
418
|
+
|
|
419
|
+
#ifdef GGML_SYCL_F16
|
|
420
|
+
if constexpr (std::is_same_v<T, sycl::half>) {
|
|
421
|
+
const sycl::half2 dm = x[ib].dm;
|
|
422
|
+
const sycl::half2 d = sycl::half2(dm[0]);
|
|
423
|
+
const sycl::half2 m = sycl::half2(dm[1]);
|
|
424
|
+
|
|
425
|
+
#pragma unroll
|
|
426
|
+
for (int l0 = 0; l0 < ne; l0 += 2) {
|
|
427
|
+
((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]) + m;
|
|
428
|
+
}
|
|
429
|
+
} else
|
|
430
|
+
#endif // GGML_SYCL_F16
|
|
431
|
+
if constexpr (std::is_same_v<T, float>) {
|
|
432
|
+
const sycl::float2 dm = (x[ib].dm).template convert<float, sycl::rounding_mode::automatic>();
|
|
433
|
+
|
|
434
|
+
#pragma unroll
|
|
435
|
+
for (int l = 0; l < ne; ++l) {
|
|
436
|
+
((float *) dst)[l] = dm.x() * q8[l] + dm.y();
|
|
437
|
+
}
|
|
438
|
+
} else {
|
|
439
|
+
static_assert(std::is_same_v<T, void>, "bad type");
|
|
440
|
+
}
|
|
441
|
+
}
|
|
442
|
+
|
|
443
|
+
template <typename T, int ne>
|
|
444
|
+
static __dpct_inline__ void dequantize_V_q5_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
|
445
|
+
const block_q5_0 * x = (const block_q5_0 *) vx;
|
|
446
|
+
|
|
447
|
+
const int64_t ib = i0 / QK5_0;
|
|
448
|
+
const int idq = i0 % QK5_0;
|
|
449
|
+
const int iqs = i0 % (QK5_0/2);
|
|
450
|
+
const int shift = (i0 % QK5_0) / (QK5_0/2);
|
|
451
|
+
|
|
452
|
+
int q;
|
|
453
|
+
static_assert(ne == 2 || ne == 4, "bad ne");
|
|
454
|
+
ggml_sycl_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
|
|
455
|
+
q >>= 4*shift;
|
|
456
|
+
q &= 0x0F0F0F0F;
|
|
457
|
+
|
|
458
|
+
{
|
|
459
|
+
int qh;
|
|
460
|
+
ggml_sycl_memcpy_1<ne, 2>(&qh, x[ib].qh);
|
|
461
|
+
#pragma unroll
|
|
462
|
+
for (int l = 0; l < ne; ++l) {
|
|
463
|
+
q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
|
|
464
|
+
}
|
|
465
|
+
}
|
|
466
|
+
|
|
467
|
+
q = dpct::vectorized_binary<sycl::char4>(q, 0x10101010, dpct::sub_sat());
|
|
468
|
+
|
|
469
|
+
const int8_t * q8 = (const int8_t *) &q;
|
|
470
|
+
|
|
471
|
+
#ifdef GGML_SYCL_F16
|
|
472
|
+
if constexpr (std::is_same_v<T, sycl::half>) {
|
|
473
|
+
const sycl::half2 d = sycl::half2(x[ib].d);
|
|
474
|
+
|
|
475
|
+
#pragma unroll
|
|
476
|
+
for (int l0 = 0; l0 < ne; l0 += 2) {
|
|
477
|
+
((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]);
|
|
478
|
+
}
|
|
479
|
+
} else
|
|
480
|
+
#endif // GGML_SYCL_F16
|
|
481
|
+
if constexpr (std::is_same_v<T, float>) {
|
|
482
|
+
const float d = x[ib].d;
|
|
483
|
+
|
|
484
|
+
#pragma unroll
|
|
485
|
+
for (int l = 0; l < ne; ++l) {
|
|
486
|
+
((float *) dst)[l] = d * q8[l];
|
|
487
|
+
}
|
|
488
|
+
} else {
|
|
489
|
+
static_assert(std::is_same_v<T, void>, "bad type");
|
|
490
|
+
}
|
|
491
|
+
}
|
|
492
|
+
|
|
493
|
+
template <typename T, int ne>
|
|
494
|
+
static __dpct_inline__ void dequantize_V_q5_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
|
495
|
+
const block_q5_1 * x = (const block_q5_1 *) vx;
|
|
496
|
+
|
|
497
|
+
const int64_t ib = i0 / QK5_1;
|
|
498
|
+
const int idq = i0 % QK5_1;
|
|
499
|
+
const int iqs = i0 % (QK5_1/2);
|
|
500
|
+
const int shift = (i0 % QK5_1) / (QK5_1/2);
|
|
501
|
+
|
|
502
|
+
int q;
|
|
503
|
+
static_assert(ne == 2 || ne == 4, "bad ne");
|
|
504
|
+
ggml_sycl_memcpy_1<ne>(&q, x[ib].qs + iqs);
|
|
505
|
+
q >>= 4*shift;
|
|
506
|
+
q &= 0x0F0F0F0F;
|
|
507
|
+
|
|
508
|
+
{
|
|
509
|
+
int qh;
|
|
510
|
+
ggml_sycl_memcpy_1<ne>(&qh, x[ib].qh);
|
|
511
|
+
#pragma unroll
|
|
512
|
+
for (int l = 0; l < ne; ++l) {
|
|
513
|
+
q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
|
|
514
|
+
}
|
|
515
|
+
}
|
|
516
|
+
|
|
517
|
+
const int8_t * q8 = (const int8_t *) &q;
|
|
518
|
+
|
|
519
|
+
#ifdef GGML_SYCL_F16
|
|
520
|
+
if constexpr (std::is_same_v<T, sycl::half>) {
|
|
521
|
+
const sycl::half2 dm = x[ib].dm;
|
|
522
|
+
const sycl::half2 d = sycl::half2(dm[0]);
|
|
523
|
+
const sycl::half2 m = sycl::half2(dm[1]);
|
|
524
|
+
|
|
525
|
+
#pragma unroll
|
|
526
|
+
for (int l0 = 0; l0 < ne; l0 += 2) {
|
|
527
|
+
((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]) + m;
|
|
528
|
+
}
|
|
529
|
+
} else
|
|
530
|
+
#endif // GGML_SYCL_F16
|
|
531
|
+
if constexpr (std::is_same_v<T, float>) {
|
|
532
|
+
const sycl::float2 dm = (x[ib].dm).template convert<float, sycl::rounding_mode::automatic>();
|
|
533
|
+
|
|
534
|
+
#pragma unroll
|
|
535
|
+
for (int l = 0; l < ne; ++l) {
|
|
536
|
+
((float *) dst)[l] = dm.x() * q8[l] + dm.y();
|
|
537
|
+
}
|
|
538
|
+
} else {
|
|
539
|
+
static_assert(std::is_same_v<T, void>, "bad type");
|
|
540
|
+
}
|
|
541
|
+
}
|
|
542
|
+
|
|
543
|
+
template <typename T, int ne>
|
|
544
|
+
static __dpct_inline__ void dequantize_V_q8_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
|
545
|
+
const block_q8_0 * x = (const block_q8_0 *) vx;
|
|
546
|
+
|
|
547
|
+
const int64_t ib = i0 / QK8_0;
|
|
548
|
+
const int iqs = i0 % QK8_0;
|
|
549
|
+
|
|
550
|
+
static_assert(ne % 2 == 0, "bad ne");
|
|
551
|
+
int8_t qs[ne];
|
|
552
|
+
ggml_sycl_memcpy_1<ne, 2>(qs, x[ib].qs + iqs);
|
|
553
|
+
|
|
554
|
+
#ifdef GGML_SYCL_F16
|
|
555
|
+
if constexpr (std::is_same<T, sycl::half>::value) {
|
|
556
|
+
const sycl::half2 d = sycl::half2(x[ib].d);
|
|
557
|
+
|
|
558
|
+
#pragma unroll
|
|
559
|
+
for (int l0 = 0; l0 < ne; l0 += 2) {
|
|
560
|
+
((sycl::half2 *) dst)[l0 / 2] = d * make_half2(qs[l0 + 0], qs[l0 + 1]);
|
|
561
|
+
}
|
|
562
|
+
} else
|
|
563
|
+
#endif // GGML_SYCL_F16
|
|
564
|
+
if constexpr (std::is_same<T, float>::value) {
|
|
565
|
+
const float d = x[ib].d;
|
|
566
|
+
|
|
567
|
+
#pragma unroll
|
|
568
|
+
for (int l = 0; l < ne; ++l) {
|
|
569
|
+
((float *) dst)[l] = d * qs[l];
|
|
570
|
+
}
|
|
571
|
+
} else {
|
|
572
|
+
static_assert(std::is_same_v<T, void>, "unsupported type");
|
|
573
|
+
}
|
|
574
|
+
}
|
|
575
|
+
|
|
576
|
+
template <int type_K, int D, int nthreads, int warp_size>
|
|
577
|
+
constexpr vec_dot_KQ_t get_vec_dot_KQ() {
|
|
578
|
+
if constexpr (type_K == GGML_TYPE_F16) {
|
|
579
|
+
return vec_dot_fattn_vec_KQ_f16<D, nthreads>;
|
|
580
|
+
} else if constexpr (type_K == GGML_TYPE_Q4_0) {
|
|
581
|
+
return vec_dot_fattn_vec_KQ_q4_0<D, nthreads, warp_size>;
|
|
582
|
+
} else if constexpr (type_K == GGML_TYPE_Q4_1) {
|
|
583
|
+
return vec_dot_fattn_vec_KQ_q4_1<D, nthreads, warp_size>;
|
|
584
|
+
} else if constexpr (type_K == GGML_TYPE_Q5_0) {
|
|
585
|
+
return vec_dot_fattn_vec_KQ_q5_0<D, nthreads, warp_size>;
|
|
586
|
+
} else if constexpr (type_K == GGML_TYPE_Q5_1) {
|
|
587
|
+
return vec_dot_fattn_vec_KQ_q5_1<D, nthreads, warp_size>;
|
|
588
|
+
} else if constexpr (type_K == GGML_TYPE_Q8_0) {
|
|
589
|
+
return vec_dot_fattn_vec_KQ_q8_0<D, nthreads, warp_size>;
|
|
590
|
+
} else {
|
|
591
|
+
static_assert(type_K == -1, "bad type");
|
|
592
|
+
return nullptr;
|
|
593
|
+
}
|
|
594
|
+
}
|
|
595
|
+
|
|
596
|
+
template <int type_V, typename T, int ne>
|
|
597
|
+
constexpr dequantize_V_t get_dequantize_V() {
|
|
598
|
+
if constexpr (type_V == GGML_TYPE_F16) {
|
|
599
|
+
return dequantize_V_f16<T, ne>;
|
|
600
|
+
} else if constexpr (type_V == GGML_TYPE_Q4_0) {
|
|
601
|
+
return dequantize_V_q4_0<T, ne>;
|
|
602
|
+
} else if constexpr (type_V == GGML_TYPE_Q4_1) {
|
|
603
|
+
return dequantize_V_q4_1<T, ne>;
|
|
604
|
+
} else if constexpr (type_V == GGML_TYPE_Q5_0) {
|
|
605
|
+
return dequantize_V_q5_0<T, ne>;
|
|
606
|
+
} else if constexpr (type_V == GGML_TYPE_Q5_1) {
|
|
607
|
+
return dequantize_V_q5_1<T, ne>;
|
|
608
|
+
} else if constexpr (type_V == GGML_TYPE_Q8_0) {
|
|
609
|
+
return dequantize_V_q8_0<T, ne>;
|
|
610
|
+
} else {
|
|
611
|
+
static_assert(type_V == -1, "bad type");
|
|
612
|
+
return nullptr;
|
|
613
|
+
}
|
|
614
|
+
}
|
|
615
|
+
|
|
616
|
+
template <int ncols1, int warp_size>
|
|
617
|
+
static void flash_attn_mask_to_KV_max(const sycl::half2 * __restrict__ mask,
|
|
618
|
+
int * __restrict__ KV_max,
|
|
619
|
+
const int ne30,
|
|
620
|
+
const int s31,
|
|
621
|
+
const int s33,
|
|
622
|
+
int * buf_iw) {
|
|
623
|
+
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
|
624
|
+
const int ne31 = item_ct1.get_group_range(2);
|
|
625
|
+
const int tid = item_ct1.get_local_id(2);
|
|
626
|
+
const int sequence = item_ct1.get_group(1);
|
|
627
|
+
const int jt = item_ct1.get_group(2);
|
|
628
|
+
|
|
629
|
+
mask += sequence*s33 + jt*ncols1*s31;
|
|
630
|
+
|
|
631
|
+
if (tid < warp_size) {
|
|
632
|
+
buf_iw[tid] = 1;
|
|
633
|
+
}
|
|
634
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
635
|
+
|
|
636
|
+
int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
|
|
637
|
+
for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
|
|
638
|
+
int all_inf = 1;
|
|
639
|
+
|
|
640
|
+
#pragma unroll
|
|
641
|
+
for (int j = 0; j < ncols1; ++j) {
|
|
642
|
+
const sycl::float2 tmp =
|
|
643
|
+
mask[j * s31 + KV_max_sj / 2 + tid].template convert<float, sycl::rounding_mode::automatic>();
|
|
644
|
+
all_inf = all_inf && int(sycl::isinf((float) (tmp.x()))) && int(sycl::isinf((float) (tmp.y())));
|
|
645
|
+
}
|
|
646
|
+
|
|
647
|
+
all_inf = warp_reduce_all<warp_size>(all_inf);
|
|
648
|
+
if (tid % warp_size == 0) {
|
|
649
|
+
buf_iw[tid / warp_size] = all_inf;
|
|
650
|
+
}
|
|
651
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
652
|
+
all_inf = buf_iw[tid % warp_size];
|
|
653
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
654
|
+
all_inf = warp_reduce_all<warp_size>(all_inf);
|
|
655
|
+
|
|
656
|
+
if (!all_inf) {
|
|
657
|
+
break;
|
|
658
|
+
}
|
|
659
|
+
}
|
|
660
|
+
|
|
661
|
+
// If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
|
|
662
|
+
// If the break was triggered it's the lower edge of the tile with the first non-masked values.
|
|
663
|
+
// In either case, walk back the decrementation by FATTN_KQ_STRIDE.
|
|
664
|
+
KV_max_sj += FATTN_KQ_STRIDE;
|
|
665
|
+
|
|
666
|
+
if (item_ct1.get_local_id(2) != 0) {
|
|
667
|
+
return;
|
|
668
|
+
}
|
|
669
|
+
|
|
670
|
+
KV_max[sequence*ne31 + jt] = KV_max_sj;
|
|
671
|
+
}
|
|
672
|
+
|
|
673
|
+
template <int D, int ncols1, int ncols2> // D == head size
|
|
674
|
+
|
|
675
|
+
static void flash_attn_stream_k_fixup(float * __restrict__ dst,
|
|
676
|
+
const sycl::float2 * __restrict__ dst_fixup,
|
|
677
|
+
const int ne01,
|
|
678
|
+
const int ne02,
|
|
679
|
+
const int ne03,
|
|
680
|
+
const int ne11,
|
|
681
|
+
const int ne12,
|
|
682
|
+
const int nbatch_fa) {
|
|
683
|
+
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
|
684
|
+
constexpr int ncols = ncols1 * ncols2;
|
|
685
|
+
|
|
686
|
+
const int bidx0 = item_ct1.get_group(2);
|
|
687
|
+
const int j = item_ct1.get_group(1);
|
|
688
|
+
const int c = item_ct1.get_group(0);
|
|
689
|
+
const int jc = j*ncols2 + c;
|
|
690
|
+
const int tid = item_ct1.get_local_id(2);
|
|
691
|
+
|
|
692
|
+
const float * dst_fixup_data = ((const float *) dst_fixup) + item_ct1.get_group_range(2) * (2 * 2 * ncols);
|
|
693
|
+
|
|
694
|
+
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
|
695
|
+
|
|
696
|
+
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
|
|
697
|
+
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
|
698
|
+
const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
|
|
699
|
+
|
|
700
|
+
const int kbc0 = int64_t(bidx0 + 0) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2);
|
|
701
|
+
const int kbc0_stop =
|
|
702
|
+
int64_t(bidx0 + 1) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2);
|
|
703
|
+
|
|
704
|
+
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
|
705
|
+
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
|
706
|
+
const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
|
|
707
|
+
if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
|
|
708
|
+
return;
|
|
709
|
+
}
|
|
710
|
+
|
|
711
|
+
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
|
|
712
|
+
const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
|
|
713
|
+
const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
|
|
714
|
+
const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
|
|
715
|
+
const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
|
|
716
|
+
|
|
717
|
+
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
|
|
718
|
+
|
|
719
|
+
if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
|
|
720
|
+
return;
|
|
721
|
+
}
|
|
722
|
+
|
|
723
|
+
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
|
|
724
|
+
|
|
725
|
+
// Load the partial result that needs a fixup:
|
|
726
|
+
float dst_val = 0.0f;
|
|
727
|
+
float max_val = 0.0f;
|
|
728
|
+
float rowsum = 0.0f;
|
|
729
|
+
{
|
|
730
|
+
dst_val = *dst;
|
|
731
|
+
|
|
732
|
+
const sycl::float2 tmp = dst_fixup[bidx0 * ncols + jc];
|
|
733
|
+
max_val = tmp.x();
|
|
734
|
+
rowsum = tmp.y();
|
|
735
|
+
}
|
|
736
|
+
|
|
737
|
+
// Iterate over previous blocks and compute the combined results.
|
|
738
|
+
// All SYCL blocks that get here must have a previous block that needs a fixup.
|
|
739
|
+
int bidx = bidx0 - 1;
|
|
740
|
+
int kbc_stop = kbc0;
|
|
741
|
+
while(true) {
|
|
742
|
+
const int kbc = int64_t(bidx) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2);
|
|
743
|
+
if (kbc == kbc_stop) { // Did not have any data.
|
|
744
|
+
bidx--;
|
|
745
|
+
kbc_stop = kbc;
|
|
746
|
+
continue;
|
|
747
|
+
}
|
|
748
|
+
|
|
749
|
+
const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
|
|
750
|
+
|
|
751
|
+
const sycl::float2 tmp = dst_fixup[(item_ct1.get_group_range(2) + bidx) * ncols + jc];
|
|
752
|
+
|
|
753
|
+
// Scale the current and new value accumulators depending on the max. values.
|
|
754
|
+
const float max_val_new = sycl::fmax(max_val, tmp.x());
|
|
755
|
+
|
|
756
|
+
const float diff_val = max_val - max_val_new;
|
|
757
|
+
const float diff_add = tmp.x() - max_val_new;
|
|
758
|
+
|
|
759
|
+
const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? sycl::native::exp(diff_val) : 0.0f;
|
|
760
|
+
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? sycl::native::exp(diff_add) : 0.0f;
|
|
761
|
+
|
|
762
|
+
dst_val = scale_val*dst_val + scale_add*dst_add;
|
|
763
|
+
rowsum = scale_val * rowsum + scale_add * tmp.y();
|
|
764
|
+
|
|
765
|
+
max_val = max_val_new;
|
|
766
|
+
|
|
767
|
+
// If this block started in a previous tile we are done and don't need to combine additional partial results.
|
|
768
|
+
if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
|
|
769
|
+
break;
|
|
770
|
+
}
|
|
771
|
+
bidx--;
|
|
772
|
+
kbc_stop = kbc;
|
|
773
|
+
}
|
|
774
|
+
|
|
775
|
+
// Write back final result:
|
|
776
|
+
*dst = dst_val / rowsum;
|
|
777
|
+
}
|
|
778
|
+
|
|
779
|
+
template <int D> // D == head size
|
|
780
|
+
|
|
781
|
+
static void flash_attn_combine_results(const float * __restrict__ VKQ_parts,
|
|
782
|
+
const sycl::float2 * __restrict__ VKQ_meta,
|
|
783
|
+
float * __restrict__ dst,
|
|
784
|
+
const int parallel_blocks,
|
|
785
|
+
uint8_t * dpct_local) {
|
|
786
|
+
// Dimension 0: threadIdx.x
|
|
787
|
+
// Dimension 1: blockIdx.x
|
|
788
|
+
// Dimension 2: blockIdx.y
|
|
789
|
+
// Dimension 3: blockIdx.z
|
|
790
|
+
// Memory layout is permuted with [0, 2, 1, 3]
|
|
791
|
+
|
|
792
|
+
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
|
793
|
+
const int ne01 = item_ct1.get_group_range(2);
|
|
794
|
+
const int ne02 = item_ct1.get_group_range(1);
|
|
795
|
+
|
|
796
|
+
const int col = item_ct1.get_group(2);
|
|
797
|
+
const int head = item_ct1.get_group(1);
|
|
798
|
+
const int sequence = item_ct1.get_group(0);
|
|
799
|
+
|
|
800
|
+
const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
|
|
801
|
+
|
|
802
|
+
VKQ_parts += j_dst_unrolled * parallel_blocks*D;
|
|
803
|
+
VKQ_meta += j_dst_unrolled * parallel_blocks;
|
|
804
|
+
dst += j_dst_unrolled * D;
|
|
805
|
+
|
|
806
|
+
const int tid = item_ct1.get_local_id(2);
|
|
807
|
+
__builtin_assume(tid < D);
|
|
808
|
+
|
|
809
|
+
auto meta = (sycl::float2 *) dpct_local;
|
|
810
|
+
for (int i = tid; i < 2*parallel_blocks; i += D) {
|
|
811
|
+
((float *) meta)[i] = ((const float *)VKQ_meta) [i];
|
|
812
|
+
}
|
|
813
|
+
|
|
814
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
815
|
+
|
|
816
|
+
float kqmax = meta[0].x();
|
|
817
|
+
for (int l = 1; l < parallel_blocks; ++l) {
|
|
818
|
+
kqmax = sycl::max(kqmax, meta[l].x());
|
|
819
|
+
}
|
|
820
|
+
|
|
821
|
+
float VKQ_numerator = 0.0f;
|
|
822
|
+
float VKQ_denominator = 0.0f;
|
|
823
|
+
for (int l = 0; l < parallel_blocks; ++l) {
|
|
824
|
+
const float KQ_max_scale = sycl::native::exp(meta[l].x() - kqmax);
|
|
825
|
+
|
|
826
|
+
VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
|
|
827
|
+
VKQ_denominator += KQ_max_scale * meta[l].y();
|
|
828
|
+
}
|
|
829
|
+
|
|
830
|
+
dst[tid] = VKQ_numerator / VKQ_denominator;
|
|
831
|
+
}
|
|
832
|
+
|
|
833
|
+
template <fattn_kernel_t fattn_kernel, int warp_size>
|
|
834
|
+
static void lauch_kernel(
|
|
835
|
+
dpct::dim3 group_range,
|
|
836
|
+
dpct::dim3 local_range,
|
|
837
|
+
queue_ptr q,
|
|
838
|
+
unsigned int local_mem_size,
|
|
839
|
+
const char* __restrict__ Q,
|
|
840
|
+
const char* __restrict__ K,
|
|
841
|
+
const char* __restrict__ V,
|
|
842
|
+
const char* __restrict__ mask,
|
|
843
|
+
const char* __restrict__ sinks,
|
|
844
|
+
const int* __restrict__ KV_max,
|
|
845
|
+
float* __restrict__ dst,
|
|
846
|
+
sycl::float2* __restrict__ dst_meta,
|
|
847
|
+
const float scale,
|
|
848
|
+
const float max_bias,
|
|
849
|
+
const float m0,
|
|
850
|
+
const float m1,
|
|
851
|
+
const uint32_t n_head_log2,
|
|
852
|
+
const float logit_softcap,
|
|
853
|
+
const int32_t ne00,
|
|
854
|
+
const sycl::uint3 ne01,
|
|
855
|
+
const int32_t ne02,
|
|
856
|
+
const int32_t ne03,
|
|
857
|
+
const int32_t nb01,
|
|
858
|
+
const int32_t nb02,
|
|
859
|
+
const int32_t nb03,
|
|
860
|
+
const int32_t ne10,
|
|
861
|
+
const int32_t ne11,
|
|
862
|
+
const int32_t ne12,
|
|
863
|
+
const int32_t ne13,
|
|
864
|
+
const int32_t nb11,
|
|
865
|
+
const int32_t nb12,
|
|
866
|
+
const int64_t nb13,
|
|
867
|
+
const int32_t nb21,
|
|
868
|
+
const int32_t nb22,
|
|
869
|
+
const int64_t nb23,
|
|
870
|
+
const int32_t ne31,
|
|
871
|
+
const int32_t ne32,
|
|
872
|
+
const int32_t ne33,
|
|
873
|
+
const int32_t nb31,
|
|
874
|
+
const int32_t nb32,
|
|
875
|
+
const int64_t nb33) {
|
|
876
|
+
GGML_UNUSED(local_mem_size);
|
|
877
|
+
q->submit([&](sycl::handler &cgh) {
|
|
878
|
+
cgh.parallel_for(
|
|
879
|
+
sycl::nd_range<3>(
|
|
880
|
+
static_cast<sycl::range<3>>(group_range * local_range),
|
|
881
|
+
static_cast<sycl::range<3>>(local_range)),
|
|
882
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] {
|
|
883
|
+
GGML_UNUSED(item_ct1);
|
|
884
|
+
fattn_kernel(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
|
885
|
+
max_bias, m0, m1, n_head_log2, logit_softcap, ne00,
|
|
886
|
+
ne01, ne02, ne03, nb01, nb02, nb03, ne10, ne11,
|
|
887
|
+
ne12, ne13, nb11, nb12, nb13, nb21, nb22, nb23,
|
|
888
|
+
ne31, ne32, ne33, nb31, nb32, nb33);
|
|
889
|
+
});
|
|
890
|
+
});
|
|
891
|
+
}
|
|
892
|
+
|
|
893
|
+
template <int DV, int ncols1, int ncols2, fattn_kernel_t fattn_kernel, int warp_size>
|
|
894
|
+
void launch_fattn(
|
|
895
|
+
ggml_backend_sycl_context & ctx, ggml_tensor * dst, const int nwarps, const size_t nbytes_shared,
|
|
896
|
+
const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k) {
|
|
897
|
+
|
|
898
|
+
constexpr int ncols = ncols1 * ncols2;
|
|
899
|
+
|
|
900
|
+
const ggml_tensor * Q = dst->src[0];
|
|
901
|
+
const ggml_tensor * K = dst->src[1];
|
|
902
|
+
const ggml_tensor * V = dst->src[2];
|
|
903
|
+
|
|
904
|
+
const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));
|
|
905
|
+
|
|
906
|
+
const ggml_tensor * mask = dst->src[3];
|
|
907
|
+
const ggml_tensor * sinks = dst->src[4];
|
|
908
|
+
|
|
909
|
+
ggml_tensor * KQV = dst;
|
|
910
|
+
|
|
911
|
+
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
|
912
|
+
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
|
|
913
|
+
|
|
914
|
+
GGML_ASSERT(Q->nb[0] == ggml_element_size(Q));
|
|
915
|
+
GGML_ASSERT(K->nb[0] == ggml_element_size(K));
|
|
916
|
+
GGML_ASSERT(V->nb[0] == ggml_element_size(V));
|
|
917
|
+
|
|
918
|
+
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
|
919
|
+
|
|
920
|
+
ggml_sycl_pool & pool = ctx.pool();
|
|
921
|
+
dpct::queue_ptr main_stream = ctx.stream();
|
|
922
|
+
const int id = ggml_sycl_get_device();
|
|
923
|
+
const int nsm = ggml_sycl_info().devices[id].nsm;
|
|
924
|
+
|
|
925
|
+
ggml_sycl_pool_alloc<sycl::half> K_f16(pool);
|
|
926
|
+
ggml_sycl_pool_alloc<sycl::half> V_f16(pool);
|
|
927
|
+
ggml_sycl_pool_alloc<int> KV_max(pool);
|
|
928
|
+
ggml_sycl_pool_alloc<float> dst_tmp(pool);
|
|
929
|
+
ggml_sycl_pool_alloc<sycl::float2> dst_tmp_meta(pool);
|
|
930
|
+
|
|
931
|
+
const char * K_data = (const char *) K->data;
|
|
932
|
+
size_t nb11 = K->nb[1];
|
|
933
|
+
size_t nb12 = K->nb[2];
|
|
934
|
+
size_t nb13 = K->nb[3];
|
|
935
|
+
|
|
936
|
+
const char * V_data = (const char *) V->data;
|
|
937
|
+
size_t nb21 = V->nb[1];
|
|
938
|
+
size_t nb22 = V->nb[2];
|
|
939
|
+
size_t nb23 = V->nb[3];
|
|
940
|
+
|
|
941
|
+
if (need_f16_K && K->type != GGML_TYPE_F16) {
|
|
942
|
+
const size_t bs = ggml_blck_size(K->type);
|
|
943
|
+
const size_t ts = ggml_type_size(K->type);
|
|
944
|
+
|
|
945
|
+
K_f16.alloc(ggml_nelements(K));
|
|
946
|
+
if (ggml_is_contiguously_allocated(K)) {
|
|
947
|
+
to_fp16_sycl_t to_fp16 = ggml_get_to_fp16_sycl(K->type, dst);
|
|
948
|
+
to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
|
|
949
|
+
|
|
950
|
+
nb11 = nb11 * bs * sizeof(sycl::half) / ts;
|
|
951
|
+
nb12 = nb12 * bs * sizeof(sycl::half) / ts;
|
|
952
|
+
nb13 = nb13 * bs * sizeof(sycl::half) / ts;
|
|
953
|
+
} else {
|
|
954
|
+
GGML_ASSERT(K->nb[0] == ts);
|
|
955
|
+
to_fp16_nc_sycl_t to_fp16 = ggml_get_to_fp16_nc_sycl(K->type);
|
|
956
|
+
const int64_t s01 = nb11 / ts;
|
|
957
|
+
const int64_t s02 = nb12 / ts;
|
|
958
|
+
const int64_t s03 = nb13 / ts;
|
|
959
|
+
to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
|
|
960
|
+
|
|
961
|
+
nb11 = K->ne[0] * sizeof(sycl::half);
|
|
962
|
+
nb12 = K->ne[1] * nb11;
|
|
963
|
+
nb13 = K->ne[2] * nb12;
|
|
964
|
+
}
|
|
965
|
+
K_data = (char *) K_f16.ptr;
|
|
966
|
+
}
|
|
967
|
+
|
|
968
|
+
if (need_f16_V && V->type != GGML_TYPE_F16) {
|
|
969
|
+
if (V_is_K_view) {
|
|
970
|
+
V_data = K_data;
|
|
971
|
+
nb21 = nb11;
|
|
972
|
+
nb22 = nb12;
|
|
973
|
+
nb23 = nb13;
|
|
974
|
+
} else {
|
|
975
|
+
const size_t bs = ggml_blck_size(V->type);
|
|
976
|
+
const size_t ts = ggml_type_size(V->type);
|
|
977
|
+
|
|
978
|
+
V_f16.alloc(ggml_nelements(V));
|
|
979
|
+
if (ggml_is_contiguously_allocated(V)) {
|
|
980
|
+
to_fp16_sycl_t to_fp16 = ggml_get_to_fp16_sycl(V->type, dst);
|
|
981
|
+
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
|
|
982
|
+
V_data = (char *) V_f16.ptr;
|
|
983
|
+
|
|
984
|
+
nb21 = nb21 * bs * sizeof(sycl::half) / ts;
|
|
985
|
+
nb22 = nb22 * bs * sizeof(sycl::half) / ts;
|
|
986
|
+
nb23 = nb23 * bs * sizeof(sycl::half) / ts;
|
|
987
|
+
} else {
|
|
988
|
+
GGML_ASSERT(V->nb[0] == ts);
|
|
989
|
+
to_fp16_nc_sycl_t to_fp16 = ggml_get_to_fp16_nc_sycl(V->type);
|
|
990
|
+
const int64_t s01 = nb21 / ts;
|
|
991
|
+
const int64_t s02 = nb22 / ts;
|
|
992
|
+
const int64_t s03 = nb23 / ts;
|
|
993
|
+
to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
|
|
994
|
+
|
|
995
|
+
nb21 = V->ne[0] * sizeof(sycl::half);
|
|
996
|
+
nb22 = V->ne[1] * nb21;
|
|
997
|
+
nb23 = V->ne[2] * nb22;
|
|
998
|
+
}
|
|
999
|
+
V_data = (char *) V_f16.ptr;
|
|
1000
|
+
}
|
|
1001
|
+
}
|
|
1002
|
+
|
|
1003
|
+
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
|
1004
|
+
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
|
1005
|
+
const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
|
|
1006
|
+
const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
|
|
1007
|
+
|
|
1008
|
+
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
|
|
1009
|
+
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
|
|
1010
|
+
// multiple sequences of possibly different lengths.
|
|
1011
|
+
if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
|
|
1012
|
+
const int s31 = mask->nb[1] / sizeof(sycl::half2);
|
|
1013
|
+
const int s33 = mask->nb[3] / sizeof(sycl::half2);
|
|
1014
|
+
|
|
1015
|
+
const dpct::dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
|
|
1016
|
+
const dpct::dim3 block_dim_KV_max(FATTN_KQ_STRIDE / 2, 1, 1);
|
|
1017
|
+
|
|
1018
|
+
const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
|
|
1019
|
+
const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
|
|
1020
|
+
|
|
1021
|
+
KV_max.alloc(ne_KV_max);
|
|
1022
|
+
{
|
|
1023
|
+
dpct::has_capability_or_fail(main_stream->get_device(), { sycl::aspect::fp16 });
|
|
1024
|
+
|
|
1025
|
+
main_stream->submit([&](sycl::handler & cgh) {
|
|
1026
|
+
sycl::local_accessor<int, 1> buf_iw_acc_ct1(sycl::range<1>(warp_size), cgh);
|
|
1027
|
+
|
|
1028
|
+
auto mask_data_ct0 = (const sycl::half2 *) mask->data;
|
|
1029
|
+
auto KV_max_ptr_ct1 = KV_max.ptr;
|
|
1030
|
+
|
|
1031
|
+
cgh.parallel_for(sycl::nd_range<3>(blocks_num_KV_max * block_dim_KV_max, block_dim_KV_max),
|
|
1032
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
1033
|
+
GGML_UNUSED(item_ct1);
|
|
1034
|
+
flash_attn_mask_to_KV_max<ncols1, warp_size>(
|
|
1035
|
+
mask_data_ct0, KV_max_ptr_ct1, iter_k, s31, s33,
|
|
1036
|
+
buf_iw_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
|
|
1037
|
+
});
|
|
1038
|
+
});
|
|
1039
|
+
}
|
|
1040
|
+
SYCL_CHECK(0);
|
|
1041
|
+
}
|
|
1042
|
+
|
|
1043
|
+
const dpct::dim3 block_dim(warp_size, nwarps, 1);
|
|
1044
|
+
|
|
1045
|
+
// Max. number of active blocks limited by occupancy.
|
|
1046
|
+
int max_blocks_per_sm = ggml_sycl_info().devices[id].max_wg_per_cu;
|
|
1047
|
+
int parallel_blocks = max_blocks_per_sm;
|
|
1048
|
+
dpct::dim3 blocks_num;
|
|
1049
|
+
if (stream_k) {
|
|
1050
|
+
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
|
|
1051
|
+
const int max_blocks = max_blocks_per_sm*nsm;
|
|
1052
|
+
const int nblocks_stream_k = max_blocks;
|
|
1053
|
+
const bool use_stream_k = true;
|
|
1054
|
+
|
|
1055
|
+
blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
|
|
1056
|
+
blocks_num.y = 1;
|
|
1057
|
+
blocks_num.z = 1;
|
|
1058
|
+
|
|
1059
|
+
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
|
1060
|
+
dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
|
|
1061
|
+
}
|
|
1062
|
+
} else {
|
|
1063
|
+
const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
|
|
1064
|
+
|
|
1065
|
+
// parallel_blocks must not be larger than what the tensor size allows:
|
|
1066
|
+
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
|
|
1067
|
+
// todo fix the hard code change
|
|
1068
|
+
// parallel_blocks = ntiles_KQ;
|
|
1069
|
+
|
|
1070
|
+
// If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
|
|
1071
|
+
// Test whether parallel_blocks can be set to a higher value for better efficiency.
|
|
1072
|
+
const int blocks_per_wave = nsm * max_blocks_per_sm;
|
|
1073
|
+
int nwaves_best = 0;
|
|
1074
|
+
int efficiency_percent_best = 0;
|
|
1075
|
+
for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
|
|
1076
|
+
const int nblocks_total = ntiles_total * parallel_blocks_test;
|
|
1077
|
+
const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
|
|
1078
|
+
const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
|
|
1079
|
+
|
|
1080
|
+
// Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
|
|
1081
|
+
if (efficiency_percent_best >= 95 && nwaves > nwaves_best) {
|
|
1082
|
+
break;
|
|
1083
|
+
}
|
|
1084
|
+
|
|
1085
|
+
if (efficiency_percent > efficiency_percent_best) {
|
|
1086
|
+
nwaves_best = nwaves;
|
|
1087
|
+
efficiency_percent_best = efficiency_percent;
|
|
1088
|
+
parallel_blocks = parallel_blocks_test;
|
|
1089
|
+
}
|
|
1090
|
+
}
|
|
1091
|
+
|
|
1092
|
+
blocks_num.x = ntiles_x;
|
|
1093
|
+
blocks_num.y = parallel_blocks;
|
|
1094
|
+
blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3];
|
|
1095
|
+
|
|
1096
|
+
if (parallel_blocks > 1) {
|
|
1097
|
+
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
|
1098
|
+
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
|
|
1099
|
+
}
|
|
1100
|
+
}
|
|
1101
|
+
|
|
1102
|
+
float scale = 1.0f;
|
|
1103
|
+
float max_bias = 0.0f;
|
|
1104
|
+
float logit_softcap = 0.0f;
|
|
1105
|
+
|
|
1106
|
+
memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float));
|
|
1107
|
+
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
|
1108
|
+
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
|
1109
|
+
|
|
1110
|
+
if (logit_softcap != 0.0f) {
|
|
1111
|
+
scale /= logit_softcap;
|
|
1112
|
+
}
|
|
1113
|
+
|
|
1114
|
+
const uint32_t n_head = Q->ne[2];
|
|
1115
|
+
const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
|
|
1116
|
+
|
|
1117
|
+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
1118
|
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
1119
|
+
|
|
1120
|
+
// TODO other tensor dimensions after removal of WMMA kernel:
|
|
1121
|
+
const sycl::uint3 ne01 = init_fastdiv_values(Q->ne[1]);
|
|
1122
|
+
|
|
1123
|
+
GGML_ASSERT(block_dim.x % warp_size == 0);
|
|
1124
|
+
|
|
1125
|
+
lauch_kernel<fattn_kernel, warp_size>(
|
|
1126
|
+
blocks_num, block_dim, main_stream, (unsigned int) nbytes_shared, (const char *) Q->data, K_data, V_data,
|
|
1127
|
+
mask ? ((const char *) mask->data) : nullptr, sinks ? ((const char *) sinks->data) : nullptr, KV_max.ptr,
|
|
1128
|
+
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, (sycl::float2 *)dst_tmp_meta.ptr, scale, max_bias, m0, m1,
|
|
1129
|
+
n_head_log2, logit_softcap, Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], K->ne[0],
|
|
1130
|
+
K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13, nb21, nb22, nb23, mask ? mask->ne[1] : 0,
|
|
1131
|
+
mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
|
|
1132
|
+
mask ? mask->nb[3] : 0);
|
|
1133
|
+
SYCL_CHECK(0);
|
|
1134
|
+
|
|
1135
|
+
if (stream_k) {
|
|
1136
|
+
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
|
1137
|
+
const dpct::dim3 block_dim_combine(DV, 1, 1);
|
|
1138
|
+
const dpct::dim3 blocks_num_combine = { blocks_num.x, ncols1, ncols2 };
|
|
1139
|
+
|
|
1140
|
+
main_stream->submit([&](sycl::handler & cgh) {
|
|
1141
|
+
auto KQV_data_ct0 = (float *) KQV->data;
|
|
1142
|
+
auto dst_tmp_meta_ptr_ct1 = dst_tmp_meta.ptr;
|
|
1143
|
+
auto Q_ne_ct2 = Q->ne[1];
|
|
1144
|
+
auto Q_ne_ct3 = Q->ne[2];
|
|
1145
|
+
auto Q_ne_ct4 = Q->ne[3];
|
|
1146
|
+
auto K_ne_ct5 = K->ne[1];
|
|
1147
|
+
auto K_ne_ct6 = K->ne[2];
|
|
1148
|
+
|
|
1149
|
+
cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine),
|
|
1150
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
1151
|
+
GGML_UNUSED(item_ct1);
|
|
1152
|
+
flash_attn_stream_k_fixup<DV, ncols1, ncols2>(KQV_data_ct0, dst_tmp_meta_ptr_ct1,
|
|
1153
|
+
Q_ne_ct2, Q_ne_ct3, Q_ne_ct4,
|
|
1154
|
+
K_ne_ct5, K_ne_ct6, nbatch_fa);
|
|
1155
|
+
});
|
|
1156
|
+
});
|
|
1157
|
+
}
|
|
1158
|
+
} else if (parallel_blocks > 1) {
|
|
1159
|
+
const dpct::dim3 block_dim_combine(DV, 1, 1);
|
|
1160
|
+
const dpct::dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
|
|
1161
|
+
const size_t nbytes_shared_combine = parallel_blocks * sizeof(sycl::float2);
|
|
1162
|
+
main_stream->submit([&](sycl::handler & cgh) {
|
|
1163
|
+
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(sycl::range<1>(nbytes_shared_combine), cgh);
|
|
1164
|
+
|
|
1165
|
+
auto dst_tmp_ptr_ct0 = dst_tmp.ptr;
|
|
1166
|
+
auto dst_tmp_meta_ptr_ct1 = dst_tmp_meta.ptr;
|
|
1167
|
+
auto KQV_data_ct2 = (float *) KQV->data;
|
|
1168
|
+
|
|
1169
|
+
cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine),
|
|
1170
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
1171
|
+
GGML_UNUSED(item_ct1);
|
|
1172
|
+
flash_attn_combine_results<DV>(
|
|
1173
|
+
dst_tmp_ptr_ct0, dst_tmp_meta_ptr_ct1, KQV_data_ct2, parallel_blocks,
|
|
1174
|
+
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
|
|
1175
|
+
});
|
|
1176
|
+
});
|
|
1177
|
+
}
|
|
1178
|
+
SYCL_CHECK(0);
|
|
1179
|
+
}
|