whispercpp 1.3.5 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +99 -2
- data/ext/extconf.rb +1 -0
- data/ext/ruby_whisper.c +20 -4
- data/ext/ruby_whisper.h +30 -2
- data/ext/ruby_whisper_context.c +216 -124
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_params.c +0 -1
- data/ext/ruby_whisper_segment.c +0 -1
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +4 -1
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +8 -0
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/server/server.cpp +18 -4
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
- data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
- data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
- data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
- data/ext/sources/examples/talk-llama/llama-context.h +27 -28
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
- data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
- data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
- data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
- data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
- data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
- data/ext/sources/examples/talk-llama/llama-model.h +72 -19
- data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
- data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
- data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -22
- data/ext/sources/examples/talk-llama/llama.h +63 -30
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
- data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
- data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
- data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
- data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
- data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/models.h +181 -46
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
- data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
- data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
- data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
- data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
- data/ext/sources/ggml/CMakeLists.txt +9 -3
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +6 -1
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +56 -9
- data/ext/sources/ggml/src/CMakeLists.txt +3 -0
- data/ext/sources/ggml/src/ggml-alloc.c +4 -9
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
- data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
- data/ext/sources/ggml/src/ggml-impl.h +62 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
- data/ext/sources/ggml/src/ggml.c +167 -33
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/src/whisper.cpp +6 -28
- data/sig/whisper.rbs +43 -2
- data/test/test_context_params.rb +82 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_whisper.rb +20 -0
- data/whispercpp.gemspec +1 -1
- metadata +240 -28
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
|
@@ -2,166 +2,288 @@
|
|
|
2
2
|
#pragma clang diagnostic ignored "-Wunused-function"
|
|
3
3
|
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
|
4
4
|
|
|
5
|
-
#
|
|
6
|
-
# define FARF_HIGH 1
|
|
7
|
-
#endif
|
|
5
|
+
#include <assert.h>
|
|
8
6
|
#include <HAP_farf.h>
|
|
9
|
-
#include <HAP_mem.h>
|
|
10
7
|
#include <HAP_perf.h>
|
|
11
|
-
#include <hexagon_protos.h>
|
|
12
|
-
#include <hexagon_types.h>
|
|
13
8
|
#include <math.h>
|
|
14
9
|
#include <string.h>
|
|
15
10
|
|
|
11
|
+
#include "hex-dma.h"
|
|
12
|
+
#include "hvx-utils.h"
|
|
13
|
+
#include "hvx-dump.h"
|
|
14
|
+
|
|
16
15
|
#define GGML_COMMON_DECL_C
|
|
17
16
|
#include "ggml-common.h"
|
|
18
17
|
#include "htp-ctx.h"
|
|
19
|
-
#include "htp-dma.h"
|
|
20
18
|
#include "htp-msg.h"
|
|
21
19
|
#include "htp-ops.h"
|
|
22
|
-
#include "hvx-utils.h"
|
|
23
|
-
#include "ops-utils.h"
|
|
24
20
|
|
|
25
|
-
//
|
|
26
|
-
|
|
27
|
-
|
|
21
|
+
// Must be multiple of 32
|
|
22
|
+
#define FLASH_ATTN_BLOCK_SIZE (32 * 2)
|
|
23
|
+
|
|
24
|
+
// This is a bit of a hack because the compiler is strugling to properly inline
|
|
25
|
+
// the default hvx_vec_f32_to_f16 with output into the local array.
|
|
26
|
+
static void __attribute__((noinline)) hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
|
|
27
|
+
{
|
|
28
|
+
*(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1);
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
// Dot product of two F16 vectors, accumulating to float
|
|
32
|
+
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
|
|
28
33
|
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
|
|
34
|
+
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
|
|
29
35
|
|
|
30
36
|
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
|
31
37
|
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
|
32
38
|
|
|
33
|
-
|
|
34
|
-
HVX_Vector rsum = Q6_V_vsplat_R(0);
|
|
39
|
+
HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
|
|
35
40
|
|
|
36
41
|
uint32_t i = 0;
|
|
37
42
|
|
|
38
43
|
#pragma unroll(4)
|
|
39
44
|
for (i = 0; i < nvec; i++) {
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
|
|
43
|
-
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
|
|
44
|
-
|
|
45
|
-
// Load x (fp16)
|
|
46
|
-
HVX_Vector x_hf = vx[i];
|
|
45
|
+
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, vx[i], vy[i]);
|
|
46
|
+
}
|
|
47
47
|
|
|
48
|
-
|
|
48
|
+
if (nloe) {
|
|
49
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
|
50
|
+
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
|
|
51
|
+
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
|
|
49
52
|
|
|
50
|
-
|
|
53
|
+
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
|
|
51
54
|
}
|
|
52
55
|
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
|
|
56
|
+
HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)));
|
|
57
|
+
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)));
|
|
58
|
+
hvx_vec_store_u(r, 4, rsum);
|
|
59
|
+
}
|
|
58
60
|
|
|
59
|
-
|
|
60
|
-
|
|
61
|
+
static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y,
|
|
62
|
+
const uint8_t * restrict x,
|
|
63
|
+
const size_t stride_x,
|
|
64
|
+
const size_t nvec,
|
|
65
|
+
const size_t nloe) {
|
|
66
|
+
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x; // fp16
|
|
67
|
+
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) (x + stride_x); // fp16
|
|
68
|
+
const HVX_Vector * restrict vx2 = (const HVX_Vector * restrict) (x + stride_x * 2); // fp16
|
|
69
|
+
const HVX_Vector * restrict vx3 = (const HVX_Vector * restrict) (x + stride_x * 3); // fp16
|
|
70
|
+
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
|
|
71
|
+
|
|
72
|
+
HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
|
|
73
|
+
HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
|
|
74
|
+
HVX_VectorPair rsum2_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
|
|
75
|
+
HVX_VectorPair rsum3_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
|
|
61
76
|
|
|
62
|
-
|
|
63
|
-
// Note that we need to clear both x and y because they may contain NANs
|
|
64
|
-
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
|
65
|
-
x_hf = Q6_V_vand_QV(bmask, x_hf);
|
|
66
|
-
y_hf = Q6_V_vand_QV(bmask, y_hf);
|
|
77
|
+
uint32_t i = 0;
|
|
67
78
|
|
|
68
|
-
|
|
79
|
+
for (i = 0; i < nvec; i++) {
|
|
80
|
+
HVX_Vector y_hf = vy[i];
|
|
81
|
+
HVX_Vector x0_hf = vx0[i];
|
|
82
|
+
HVX_Vector x1_hf = vx1[i];
|
|
83
|
+
HVX_Vector x2_hf = vx2[i];
|
|
84
|
+
HVX_Vector x3_hf = vx3[i];
|
|
85
|
+
|
|
86
|
+
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
|
|
87
|
+
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
|
|
88
|
+
rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
|
|
89
|
+
rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
|
|
90
|
+
}
|
|
69
91
|
|
|
70
|
-
|
|
92
|
+
if (nloe) {
|
|
93
|
+
// Load x (fp16) and zero-out unused elements
|
|
94
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
|
95
|
+
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
|
|
96
|
+
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
|
|
97
|
+
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
|
|
98
|
+
HVX_Vector x2_hf = Q6_V_vand_QV(bmask, vx2[i]);
|
|
99
|
+
HVX_Vector x3_hf = Q6_V_vand_QV(bmask, vx3[i]);
|
|
100
|
+
|
|
101
|
+
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
|
|
102
|
+
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
|
|
103
|
+
rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
|
|
104
|
+
rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
|
|
71
105
|
}
|
|
72
106
|
|
|
73
|
-
|
|
74
|
-
|
|
107
|
+
HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)));
|
|
108
|
+
HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)));
|
|
109
|
+
HVX_Vector rsum2 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p)));
|
|
110
|
+
HVX_Vector rsum3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p)));
|
|
75
111
|
|
|
76
|
-
|
|
112
|
+
HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } };
|
|
113
|
+
return hvx_vec_reduce_sum_f32x4(rsum0123);
|
|
77
114
|
}
|
|
78
115
|
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
116
|
+
static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y,
|
|
117
|
+
const uint8_t * restrict x,
|
|
118
|
+
const size_t stride_x,
|
|
119
|
+
const size_t n,
|
|
120
|
+
float s) {
|
|
121
|
+
|
|
122
|
+
const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
|
123
|
+
const size_t nloe = n % VLEN_FP16; // leftover elements
|
|
124
|
+
|
|
125
|
+
HVX_Vector sums; // initialize at j = 0
|
|
126
|
+
const size_t stride_x_4 = stride_x * 4;
|
|
127
|
+
for (uint32_t j = 0; j < VLEN_FP32; j += 4) {
|
|
128
|
+
HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe);
|
|
129
|
+
HVX_VectorPred pred = Q6_Q_vsetq_R(j * SIZEOF_FP32);
|
|
130
|
+
sums = Q6_V_vmux_QVV(pred, sums, sums_x4);
|
|
131
|
+
x += stride_x_4;
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
sums = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), sums);
|
|
135
|
+
return Q6_Vsf_equals_Vqf32(sums);
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
// MAD: y (F32) += x (F16) * s (F16)
|
|
139
|
+
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, const __fp16 * restrict s, int n) {
|
|
140
|
+
const HVX_Vector * restrict vx0 = (const HVX_Vector *) x;
|
|
141
|
+
|
|
142
|
+
HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;
|
|
143
|
+
HVX_Vector * restrict vy = (HVX_Vector *) y;
|
|
83
144
|
|
|
84
145
|
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
|
85
146
|
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
|
86
147
|
|
|
87
|
-
|
|
88
|
-
HVX_Vector rsum = Q6_V_vsplat_R(0);
|
|
148
|
+
HVX_Vector S0 = hvx_vec_splat_f16(*s);
|
|
89
149
|
|
|
90
150
|
uint32_t i = 0;
|
|
91
151
|
|
|
92
|
-
#pragma unroll(
|
|
93
|
-
for (i = 0; i < nvec; i
|
|
94
|
-
|
|
95
|
-
HVX_Vector x_hf = vx[i];
|
|
96
|
-
|
|
97
|
-
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
|
98
|
-
|
|
99
|
-
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
|
152
|
+
#pragma unroll(2)
|
|
153
|
+
for (i = 0; i < nvec; ++i) {
|
|
154
|
+
vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
|
|
100
155
|
}
|
|
101
156
|
|
|
102
157
|
if (nloe) {
|
|
103
|
-
|
|
158
|
+
HVX_VectorPair xy_p = vy_p[i];
|
|
159
|
+
xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
|
|
104
160
|
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
|
|
161
|
+
HVX_Vector xy = Q6_V_lo_W(xy_p);
|
|
162
|
+
i = 2 * i; // index for vy
|
|
108
163
|
|
|
109
|
-
|
|
164
|
+
if (nloe >= VLEN_FP32) {
|
|
165
|
+
vy[i] = xy;
|
|
166
|
+
nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
|
|
167
|
+
}
|
|
110
168
|
|
|
111
|
-
|
|
169
|
+
if (nloe) {
|
|
170
|
+
hvx_vec_store_a(&vy[i], nloe * 4, xy);
|
|
171
|
+
}
|
|
112
172
|
}
|
|
113
|
-
|
|
114
|
-
rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
|
|
115
|
-
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
|
|
116
|
-
hvx_vec_store_u(r, 4, rsum);
|
|
117
173
|
}
|
|
118
174
|
|
|
119
|
-
// MAD: y (F32) +=
|
|
120
|
-
static inline void
|
|
121
|
-
|
|
122
|
-
HVX_Vector * restrict
|
|
175
|
+
// MAD: y (F32) += x0 (F16) * s0 (F16) + x1 (F16) * s1 (F16)
|
|
176
|
+
static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, const void * restrict x0, const void * restrict x1,
|
|
177
|
+
const __fp16 * restrict s0, const __fp16 * restrict s1, int n) {
|
|
178
|
+
const HVX_Vector * restrict vx0 = (const HVX_Vector *) x0;
|
|
179
|
+
const HVX_Vector * restrict vx1 = (const HVX_Vector *) x1;
|
|
123
180
|
|
|
124
|
-
|
|
125
|
-
|
|
181
|
+
HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;
|
|
182
|
+
HVX_Vector * restrict vy = (HVX_Vector *) y;
|
|
126
183
|
|
|
127
|
-
|
|
184
|
+
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
|
185
|
+
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
|
186
|
+
|
|
187
|
+
HVX_Vector S0 = hvx_vec_splat_f16(*s0);
|
|
188
|
+
HVX_Vector S1 = hvx_vec_splat_f16(*s1);
|
|
128
189
|
|
|
129
190
|
uint32_t i = 0;
|
|
130
|
-
|
|
191
|
+
|
|
192
|
+
#pragma unroll(2)
|
|
131
193
|
for (i = 0; i < nvec; ++i) {
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2]));
|
|
135
|
-
ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1]));
|
|
194
|
+
vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
|
|
195
|
+
vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx1[i]), S1);
|
|
136
196
|
}
|
|
137
197
|
|
|
138
198
|
if (nloe) {
|
|
139
|
-
HVX_VectorPair
|
|
199
|
+
HVX_VectorPair xy_p = vy_p[i];
|
|
200
|
+
xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
|
|
201
|
+
xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx1[i]), S1);
|
|
140
202
|
|
|
141
|
-
HVX_Vector
|
|
142
|
-
i = 2 * i;
|
|
203
|
+
HVX_Vector xy = Q6_V_lo_W(xy_p);
|
|
204
|
+
i = 2 * i; // index for vy
|
|
143
205
|
|
|
144
|
-
if (nloe >=
|
|
145
|
-
|
|
146
|
-
nloe -=
|
|
206
|
+
if (nloe >= VLEN_FP32) {
|
|
207
|
+
vy[i] = xy;
|
|
208
|
+
nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
|
|
147
209
|
}
|
|
148
210
|
|
|
149
211
|
if (nloe) {
|
|
150
|
-
|
|
151
|
-
hvx_vec_store_u(&ptr_y[i], nloe * 4, xy);
|
|
212
|
+
hvx_vec_store_a(&vy[i], nloe * 4, xy);
|
|
152
213
|
}
|
|
153
214
|
}
|
|
154
215
|
}
|
|
155
216
|
|
|
156
|
-
|
|
217
|
+
struct htp_fa_context {
|
|
218
|
+
const struct htp_ops_context * octx;
|
|
219
|
+
|
|
220
|
+
struct fastdiv_values src0_div21;
|
|
221
|
+
struct fastdiv_values src0_div1;
|
|
222
|
+
|
|
223
|
+
struct fastdiv_values broadcast_rk2;
|
|
224
|
+
struct fastdiv_values broadcast_rk3;
|
|
225
|
+
struct fastdiv_values broadcast_rv2;
|
|
226
|
+
struct fastdiv_values broadcast_rv3;
|
|
227
|
+
|
|
228
|
+
struct fastdiv_values src3_div2;
|
|
229
|
+
struct fastdiv_values src3_div3;
|
|
230
|
+
|
|
231
|
+
float scale;
|
|
232
|
+
float max_bias;
|
|
233
|
+
float logit_softcap;
|
|
234
|
+
|
|
235
|
+
uint32_t n_head_log2;
|
|
236
|
+
float m0;
|
|
237
|
+
float m1;
|
|
238
|
+
|
|
239
|
+
uint32_t n_blocks;
|
|
240
|
+
|
|
241
|
+
size_t size_q_row_padded;
|
|
242
|
+
size_t size_k_row_padded;
|
|
243
|
+
size_t size_v_row_padded;
|
|
244
|
+
|
|
245
|
+
size_t size_k_block;
|
|
246
|
+
size_t size_v_block;
|
|
247
|
+
size_t size_m_block;
|
|
157
248
|
|
|
158
|
-
|
|
249
|
+
uint32_t qrows;
|
|
250
|
+
uint32_t qrows_per_thread;
|
|
251
|
+
|
|
252
|
+
bool is_q_fp32;
|
|
253
|
+
|
|
254
|
+
uint64_t t_start;
|
|
255
|
+
};
|
|
256
|
+
|
|
257
|
+
static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, HVX_Vector vs) {
|
|
258
|
+
assert((size_t) dst % 128 == 0);
|
|
259
|
+
assert((size_t) src % 128 == 0);
|
|
260
|
+
|
|
261
|
+
const HVX_Vector * restrict vsrc = (const HVX_Vector * restrict) src;
|
|
262
|
+
HVX_Vector * restrict vdst = (HVX_Vector * restrict) dst;
|
|
263
|
+
|
|
264
|
+
const uint32_t nvec = n / VLEN_FP32;
|
|
265
|
+
const uint32_t nloe = n % VLEN_FP32;
|
|
266
|
+
|
|
267
|
+
uint32_t i = 0;
|
|
268
|
+
#pragma unroll(4)
|
|
269
|
+
for (; i < nvec; ++i) {
|
|
270
|
+
vdst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs));
|
|
271
|
+
}
|
|
272
|
+
if (nloe) {
|
|
273
|
+
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
|
|
274
|
+
hvx_vec_store_a(&vdst[i], nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v));
|
|
275
|
+
}
|
|
276
|
+
}
|
|
277
|
+
|
|
278
|
+
static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) {
|
|
279
|
+
struct htp_fa_context * factx = (struct htp_fa_context *) data;
|
|
280
|
+
const struct htp_ops_context * octx = factx->octx;
|
|
159
281
|
const struct htp_tensor * q = &octx->src0;
|
|
160
282
|
const struct htp_tensor * k = &octx->src1;
|
|
161
283
|
const struct htp_tensor * v = &octx->src2;
|
|
162
284
|
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
|
|
163
285
|
const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
|
|
164
|
-
struct htp_tensor * dst = &octx->dst;
|
|
286
|
+
const struct htp_tensor * dst = &octx->dst;
|
|
165
287
|
|
|
166
288
|
const uint32_t neq0 = q->ne[0];
|
|
167
289
|
const uint32_t neq1 = q->ne[1];
|
|
@@ -198,22 +320,9 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|
|
198
320
|
const uint32_t nb2 = dst->nb[2];
|
|
199
321
|
const uint32_t nb3 = dst->nb[3];
|
|
200
322
|
|
|
201
|
-
float scale = 1.0f;
|
|
202
|
-
float max_bias = 0.0f;
|
|
203
|
-
float logit_softcap = 0.0f;
|
|
204
|
-
|
|
205
|
-
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
|
|
206
|
-
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
|
|
207
|
-
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
|
|
208
|
-
|
|
209
|
-
if (logit_softcap != 0) {
|
|
210
|
-
scale /= logit_softcap;
|
|
211
|
-
}
|
|
212
|
-
|
|
213
323
|
// total rows in q
|
|
214
|
-
const uint32_t nr =
|
|
215
|
-
|
|
216
|
-
const uint32_t dr = (nr + nth - 1) / nth;
|
|
324
|
+
const uint32_t nr = factx->qrows;
|
|
325
|
+
const uint32_t dr = factx->qrows_per_thread;
|
|
217
326
|
const uint32_t ir0 = dr * ith;
|
|
218
327
|
const uint32_t ir1 = MIN(ir0 + dr, nr);
|
|
219
328
|
|
|
@@ -225,18 +334,8 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|
|
225
334
|
const uint32_t DV = nev0;
|
|
226
335
|
|
|
227
336
|
const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
|
|
228
|
-
const size_t size_q_row_padded = htp_round_up(size_q_row, 128);
|
|
229
|
-
|
|
230
337
|
const size_t size_k_row = DK * sizeof(__fp16);
|
|
231
338
|
const size_t size_v_row = DV * sizeof(__fp16);
|
|
232
|
-
const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask
|
|
233
|
-
|
|
234
|
-
const size_t size_k_row_padded = htp_round_up(size_k_row, 128);
|
|
235
|
-
const size_t size_v_row_padded = htp_round_up(size_v_row, 128);
|
|
236
|
-
|
|
237
|
-
const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
|
238
|
-
const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
|
239
|
-
const size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
|
|
240
339
|
|
|
241
340
|
// Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
|
|
242
341
|
uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
|
|
@@ -245,72 +344,79 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|
|
245
344
|
uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
|
|
246
345
|
uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith;
|
|
247
346
|
|
|
248
|
-
const
|
|
249
|
-
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
|
250
|
-
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
251
|
-
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
347
|
+
const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap);
|
|
252
348
|
|
|
253
349
|
for (uint32_t ir = ir0; ir < ir1; ++ir) {
|
|
254
|
-
const uint32_t iq3 = fastdiv(ir, &
|
|
255
|
-
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &
|
|
350
|
+
const uint32_t iq3 = fastdiv(ir, &factx->src0_div21);
|
|
351
|
+
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1);
|
|
256
352
|
const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);
|
|
257
353
|
|
|
258
|
-
const uint32_t ik3 = fastdiv(iq3, &
|
|
259
|
-
const uint32_t ik2 = fastdiv(iq2, &
|
|
354
|
+
const uint32_t ik3 = fastdiv(iq3, &factx->broadcast_rk3);
|
|
355
|
+
const uint32_t ik2 = fastdiv(iq2, &factx->broadcast_rk2);
|
|
260
356
|
|
|
261
|
-
const uint32_t iv3 = fastdiv(iq3, &
|
|
262
|
-
const uint32_t iv2 = fastdiv(iq2, &
|
|
357
|
+
const uint32_t iv3 = fastdiv(iq3, &factx->broadcast_rv3);
|
|
358
|
+
const uint32_t iv2 = fastdiv(iq2, &factx->broadcast_rv2);
|
|
263
359
|
|
|
264
360
|
// Fetch Q row
|
|
265
361
|
const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
|
|
266
|
-
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);
|
|
267
|
-
|
|
268
|
-
const uint32_t h = iq2; // head index
|
|
269
|
-
const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
|
|
270
|
-
|
|
271
|
-
float S = 0.0f; // sum
|
|
272
|
-
float M = -INFINITY; // maximum KQ value
|
|
362
|
+
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), factx->size_q_row_padded, nbq1, size_q_row, 1);
|
|
273
363
|
|
|
274
|
-
//
|
|
275
|
-
|
|
276
|
-
memset(VKQ32, 0, DV * sizeof(float));
|
|
364
|
+
// FARF(HIGH, "fa %u: prefetch Q: ir %u iq1 %u iq2 %u iq3 %u q_row_ptr %p size %u : usec %u", ith, ir, iq1, iq2, iq3, q_row_ptr, size_q_row,
|
|
365
|
+
// (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));
|
|
277
366
|
|
|
278
367
|
const __fp16 * mp_base = NULL;
|
|
279
368
|
if (mask) {
|
|
280
|
-
const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &
|
|
281
|
-
const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &
|
|
369
|
+
const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &factx->src3_div2);
|
|
370
|
+
const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &factx->src3_div3);
|
|
282
371
|
mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
|
|
283
372
|
}
|
|
284
373
|
|
|
285
|
-
const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
|
|
286
|
-
|
|
287
374
|
// Prefetch first two blocks
|
|
288
|
-
for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
|
|
375
|
+
for (uint32_t ib = 0; ib < MIN(factx->n_blocks, 2); ++ib) {
|
|
289
376
|
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
|
|
290
377
|
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
|
|
291
378
|
|
|
292
379
|
// K
|
|
293
380
|
const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
|
|
294
|
-
uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
|
|
295
|
-
dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);
|
|
381
|
+
uint8_t * k_dst = spad_k + (ib % 2) * factx->size_k_block;
|
|
382
|
+
dma_queue_push(dma, dma_make_ptr(k_dst, k_src), factx->size_k_row_padded, nbk1, size_k_row, current_block_size);
|
|
296
383
|
|
|
297
384
|
// V
|
|
298
385
|
const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
|
|
299
|
-
uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
|
|
300
|
-
dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);
|
|
386
|
+
uint8_t * v_dst = spad_v + (ib % 2) * factx->size_v_block;
|
|
387
|
+
dma_queue_push(dma, dma_make_ptr(v_dst, v_src), factx->size_v_row_padded, nbv1, size_v_row, current_block_size);
|
|
301
388
|
|
|
302
389
|
// Mask
|
|
303
390
|
if (mask) {
|
|
304
391
|
const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
|
|
305
|
-
uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
|
|
392
|
+
uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block;
|
|
306
393
|
// Mask is 1D contiguous for this row
|
|
307
394
|
dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
|
|
308
395
|
}
|
|
396
|
+
|
|
397
|
+
// FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
|
|
398
|
+
// ith, ir, ib, iq1, iq2, iq3,
|
|
399
|
+
// size_k_row, size_v_row, current_block_size,
|
|
400
|
+
// (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));
|
|
309
401
|
}
|
|
310
402
|
|
|
311
|
-
const
|
|
403
|
+
const uint32_t h = iq2; // head index
|
|
404
|
+
const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f;
|
|
405
|
+
|
|
406
|
+
HVX_Vector S_vec = hvx_vec_splat_f32(0.0f);
|
|
407
|
+
HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY);
|
|
312
408
|
|
|
313
|
-
|
|
409
|
+
// Clear accumulator
|
|
410
|
+
hvx_splat_f32_a(spad_a, 0, DV);
|
|
411
|
+
float * VKQ32 = (float *) (spad_a + 0);
|
|
412
|
+
|
|
413
|
+
uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
|
|
414
|
+
if (factx->is_q_fp32) {
|
|
415
|
+
hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16
|
|
416
|
+
}
|
|
417
|
+
|
|
418
|
+
const HVX_Vector slope_vec = hvx_vec_splat_f16(slope);
|
|
419
|
+
for (uint32_t ib = 0; ib < factx->n_blocks; ++ib) {
|
|
314
420
|
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
|
|
315
421
|
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
|
|
316
422
|
|
|
@@ -319,156 +425,166 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|
|
319
425
|
uint8_t * v_base = dma_queue_pop(dma).dst; // V
|
|
320
426
|
__fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M
|
|
321
427
|
|
|
428
|
+
// FARF(HIGH, "fa %u: process: ir %u ib %u : iq1 %u iq2 %u iq3 %u q_ptr_vtcm %p : usec %u",
|
|
429
|
+
// ith, ir, ib, iq1, iq2, iq3, q_ptr_vtcm,
|
|
430
|
+
// (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));
|
|
431
|
+
|
|
322
432
|
// Inner loop processing the block from VTCM
|
|
323
433
|
uint32_t ic = 0;
|
|
324
434
|
|
|
325
|
-
// Process in blocks of 32 (VLEN_FP32)
|
|
326
|
-
|
|
435
|
+
// Process in sub-blocks of 32 (VLEN_FP32)
|
|
436
|
+
HVX_Vector sb_scores[FLASH_ATTN_BLOCK_SIZE / VLEN_FP32];
|
|
437
|
+
HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
|
|
438
|
+
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
|
|
327
439
|
// 1. Compute scores
|
|
328
|
-
|
|
329
|
-
for (int j = 0; j < VLEN_FP32; ++j) {
|
|
330
|
-
const uint32_t cur_ic = ic + j;
|
|
331
|
-
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
|
|
332
|
-
if (q->type == HTP_TYPE_F32) {
|
|
333
|
-
hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
|
|
334
|
-
} else {
|
|
335
|
-
hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
|
|
336
|
-
}
|
|
337
|
-
}
|
|
338
|
-
|
|
339
|
-
HVX_Vector scores = *(HVX_Vector *) scores_arr;
|
|
440
|
+
HVX_Vector scores = hvx_dot_f16_f16_aa_rx32(q_ptr_vtcm, k_base + ic * factx->size_k_row_padded, factx->size_k_row_padded, DK, factx->scale);
|
|
340
441
|
|
|
341
442
|
// 2. Softcap
|
|
342
|
-
if (logit_softcap != 0.0f) {
|
|
343
|
-
scores =
|
|
344
|
-
scores = Q6_Vqf32_vmpy_VsfVsf(scores,
|
|
443
|
+
if (factx->logit_softcap != 0.0f) {
|
|
444
|
+
scores = hvx_vec_tanh_f32(scores);
|
|
445
|
+
scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap);
|
|
345
446
|
scores = Q6_Vsf_equals_Vqf32(scores);
|
|
346
447
|
}
|
|
347
448
|
|
|
348
449
|
// 3. Mask
|
|
349
450
|
if (mask) {
|
|
350
451
|
const __fp16 * mp = m_base + ic;
|
|
351
|
-
HVX_Vector
|
|
352
|
-
|
|
353
|
-
HVX_Vector
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
HVX_Vector m_vals_fp32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_fp32_pair));
|
|
357
|
-
|
|
358
|
-
HVX_Vector slope_vec = hvx_vec_splat_fp32(slope);
|
|
359
|
-
HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_fp32, slope_vec);
|
|
360
|
-
scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
|
|
452
|
+
HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
|
|
453
|
+
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
|
|
454
|
+
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
|
|
455
|
+
scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores);
|
|
361
456
|
scores = Q6_Vsf_equals_Vqf32(scores);
|
|
362
457
|
}
|
|
363
458
|
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
float M_old = M;
|
|
369
|
-
float M_new = (m_block > M) ? m_block : M;
|
|
370
|
-
M = M_new;
|
|
459
|
+
sb_scores[iv] = scores;
|
|
460
|
+
v_max = hvx_vec_reduce_max2_f32(scores, v_max); // All lanes have block max
|
|
461
|
+
}
|
|
371
462
|
|
|
372
|
-
|
|
463
|
+
{
|
|
464
|
+
// 4. Online Softmax Update
|
|
465
|
+
HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec);
|
|
466
|
+
HVX_Vector diff_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec));
|
|
467
|
+
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
|
|
468
|
+
M_vec = M_new_vec;
|
|
373
469
|
|
|
374
|
-
|
|
375
|
-
S = S * ms;
|
|
470
|
+
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
|
|
376
471
|
|
|
377
|
-
HVX_Vector
|
|
378
|
-
|
|
379
|
-
|
|
472
|
+
HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
|
|
473
|
+
for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
|
|
474
|
+
HVX_Vector scores = sb_scores[iv];
|
|
475
|
+
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec);
|
|
476
|
+
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
|
|
380
477
|
|
|
381
|
-
|
|
382
|
-
float p_sum = hvx_vec_get_fp32(p_sum_vec);
|
|
383
|
-
S += p_sum;
|
|
478
|
+
p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
|
|
384
479
|
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
480
|
+
// 5. Accumulate V
|
|
481
|
+
__fp16 __attribute__((aligned(VLEN))) p_arr[VLEN_FP16];
|
|
482
|
+
hvx_vec_f32_to_f16_a(p_arr, P, hvx_vec_splat_f32(0));
|
|
388
483
|
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
484
|
+
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
|
|
485
|
+
const uint32_t cur_ic = ic2 + j;
|
|
486
|
+
const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
|
|
487
|
+
hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, (p_arr + j), (p_arr + j + 1), DV);
|
|
488
|
+
}
|
|
393
489
|
}
|
|
394
|
-
}
|
|
395
490
|
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
|
|
491
|
+
p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
|
|
492
|
+
S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec));
|
|
493
|
+
}
|
|
400
494
|
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
495
|
+
if (ic < current_block_size) {
|
|
496
|
+
// Sync scalars for leftover/next block if needed
|
|
497
|
+
float M = hvx_vec_get_f32(M_vec);
|
|
498
|
+
float S = hvx_vec_get_f32(S_vec);
|
|
499
|
+
|
|
500
|
+
// Leftover
|
|
501
|
+
for (; ic < current_block_size; ++ic) {
|
|
502
|
+
float s_val;
|
|
503
|
+
const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded;
|
|
504
|
+
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale);
|
|
505
|
+
if (factx->logit_softcap != 0.0f) {
|
|
506
|
+
s_val = factx->logit_softcap * tanhf(s_val);
|
|
507
|
+
}
|
|
406
508
|
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
509
|
+
if (mask) {
|
|
510
|
+
const float m_val = m_base[ic];
|
|
511
|
+
s_val += slope * m_val;
|
|
512
|
+
}
|
|
410
513
|
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
s_val += slope * m_val;
|
|
414
|
-
}
|
|
514
|
+
const float Mold = M;
|
|
515
|
+
__fp16 vs = 1.0f;
|
|
415
516
|
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
517
|
+
if (s_val > M) {
|
|
518
|
+
M = s_val;
|
|
519
|
+
HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M);
|
|
520
|
+
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
|
|
521
|
+
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
|
|
419
522
|
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
523
|
+
float ms = hvx_vec_get_f32(ms_vec);
|
|
524
|
+
S = S * ms + vs;
|
|
525
|
+
} else {
|
|
526
|
+
HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M);
|
|
527
|
+
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
|
|
528
|
+
S += vs;
|
|
529
|
+
}
|
|
427
530
|
|
|
428
|
-
|
|
531
|
+
const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded;
|
|
429
532
|
|
|
430
|
-
|
|
533
|
+
hvx_mad_f32_f16_aa(VKQ32, v_ptr, &vs, DV);
|
|
534
|
+
}
|
|
431
535
|
|
|
432
|
-
|
|
536
|
+
M_vec = hvx_vec_splat_f32(M);
|
|
537
|
+
S_vec = hvx_vec_splat_f32(S);
|
|
433
538
|
}
|
|
434
539
|
|
|
435
540
|
// Issue DMA for next+1 block (if exists)
|
|
436
|
-
if (ib + 2 < n_blocks) {
|
|
541
|
+
if (ib + 2 < factx->n_blocks) {
|
|
437
542
|
const uint32_t next_ib = ib + 2;
|
|
438
543
|
const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
|
|
439
544
|
const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
|
|
440
545
|
|
|
441
546
|
// K
|
|
442
547
|
const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
|
|
443
|
-
dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);
|
|
548
|
+
dma_queue_push(dma, dma_make_ptr(k_base, k_src), factx->size_k_row_padded, nbk1, size_k_row, next_block_size);
|
|
444
549
|
|
|
445
550
|
// V
|
|
446
551
|
const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
|
|
447
|
-
dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);
|
|
552
|
+
dma_queue_push(dma, dma_make_ptr(v_base, v_src), factx->size_v_row_padded, nbv1, size_v_row, next_block_size);
|
|
448
553
|
|
|
449
554
|
// Mask
|
|
450
555
|
if (mask) {
|
|
451
556
|
const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
|
|
452
557
|
dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
|
|
453
558
|
}
|
|
559
|
+
|
|
560
|
+
// FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u : iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
|
|
561
|
+
// ith, ir, next_ib, iq1, iq2, iq3,
|
|
562
|
+
// size_k_row, size_v_row, next_block_size,
|
|
563
|
+
// (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));
|
|
454
564
|
}
|
|
455
565
|
}
|
|
456
566
|
|
|
457
567
|
// sinks
|
|
568
|
+
float M = hvx_vec_get_f32(M_vec);
|
|
569
|
+
float S = hvx_vec_get_f32(S_vec);
|
|
570
|
+
|
|
458
571
|
if (sinks) {
|
|
459
572
|
const float s = ((float *)((char *) sinks->data))[h];
|
|
460
573
|
|
|
461
|
-
float ms = 1.0f;
|
|
462
574
|
float vs = 1.0f;
|
|
463
575
|
|
|
464
576
|
if (s > M) {
|
|
465
|
-
|
|
466
|
-
|
|
577
|
+
HVX_Vector diff_vec = hvx_vec_splat_f32(M - s);
|
|
578
|
+
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
|
|
579
|
+
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
|
|
580
|
+
|
|
581
|
+
float ms = hvx_vec_get_f32(ms_vec);
|
|
582
|
+
S = S * ms + vs;
|
|
467
583
|
} else {
|
|
468
|
-
|
|
584
|
+
HVX_Vector diff_vec = hvx_vec_splat_f32(s - M);
|
|
585
|
+
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
|
|
586
|
+
S += vs;
|
|
469
587
|
}
|
|
470
|
-
|
|
471
|
-
S = S * ms + vs;
|
|
472
588
|
}
|
|
473
589
|
|
|
474
590
|
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
|
|
@@ -484,60 +600,91 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|
|
484
600
|
uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;
|
|
485
601
|
|
|
486
602
|
if (dst->type == HTP_TYPE_F32) {
|
|
487
|
-
|
|
603
|
+
hvx_copy_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
|
|
488
604
|
} else if (dst->type == HTP_TYPE_F16) {
|
|
489
|
-
|
|
605
|
+
hvx_copy_f16_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
|
|
490
606
|
}
|
|
491
607
|
}
|
|
492
608
|
}
|
|
493
609
|
|
|
494
|
-
static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
|
|
495
|
-
struct htp_ops_context * octx = data;
|
|
496
|
-
flash_attn_ext_f16_thread(octx, i, n);
|
|
497
|
-
}
|
|
498
|
-
|
|
499
610
|
int op_flash_attn_ext(struct htp_ops_context * octx) {
|
|
500
611
|
const struct htp_tensor * q = &octx->src0;
|
|
501
612
|
const struct htp_tensor * k = &octx->src1;
|
|
502
613
|
const struct htp_tensor * v = &octx->src2;
|
|
503
|
-
const struct htp_tensor * mask = (octx->src3.
|
|
504
|
-
struct htp_tensor * dst = &octx->dst;
|
|
614
|
+
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
|
|
615
|
+
const struct htp_tensor * dst = &octx->dst;
|
|
505
616
|
|
|
506
617
|
// Check support
|
|
507
|
-
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
|
|
508
|
-
k->type != HTP_TYPE_F16 ||
|
|
509
|
-
v->type != HTP_TYPE_F16) {
|
|
618
|
+
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) {
|
|
510
619
|
return HTP_STATUS_NO_SUPPORT;
|
|
511
620
|
}
|
|
512
621
|
|
|
513
|
-
|
|
514
|
-
octx
|
|
622
|
+
struct htp_fa_context factx;
|
|
623
|
+
factx.octx = octx;
|
|
515
624
|
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
625
|
+
factx.t_start = HAP_perf_get_qtimer_count();
|
|
626
|
+
|
|
627
|
+
factx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
|
|
628
|
+
factx.src0_div1 = init_fastdiv_values(q->ne[1]);
|
|
629
|
+
|
|
630
|
+
factx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
|
|
631
|
+
factx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
|
|
632
|
+
factx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
|
|
633
|
+
factx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
|
|
520
634
|
|
|
521
635
|
if (mask) {
|
|
522
|
-
|
|
523
|
-
|
|
636
|
+
factx.src3_div2 = init_fastdiv_values(mask->ne[2]);
|
|
637
|
+
factx.src3_div3 = init_fastdiv_values(mask->ne[3]);
|
|
638
|
+
}
|
|
639
|
+
|
|
640
|
+
factx.is_q_fp32 = (q->type == HTP_TYPE_F32);
|
|
641
|
+
factx.size_q_row_padded = hex_round_up(q->ne[0] * (factx.is_q_fp32 ? 4 : 2), 128);
|
|
642
|
+
factx.size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
|
|
643
|
+
factx.size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
|
|
644
|
+
|
|
645
|
+
size_t size_q_block = factx.size_q_row_padded * 1; // single row for now
|
|
646
|
+
factx.size_k_block = factx.size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
|
647
|
+
factx.size_v_block = factx.size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
|
648
|
+
factx.size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
|
|
649
|
+
|
|
650
|
+
factx.n_blocks = (k->ne[1] + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
|
|
651
|
+
|
|
652
|
+
float scale = 1.0f;
|
|
653
|
+
float max_bias = 0.0f;
|
|
654
|
+
float logit_softcap = 0.0f;
|
|
655
|
+
|
|
656
|
+
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
|
|
657
|
+
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
|
|
658
|
+
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
|
|
659
|
+
|
|
660
|
+
if (logit_softcap != 0.0f) {
|
|
661
|
+
scale /= logit_softcap;
|
|
524
662
|
}
|
|
525
663
|
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
664
|
+
factx.scale = scale;
|
|
665
|
+
factx.max_bias = max_bias;
|
|
666
|
+
factx.logit_softcap = logit_softcap;
|
|
667
|
+
|
|
668
|
+
uint32_t n_head = q->ne[2];
|
|
669
|
+
factx.n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
|
670
|
+
factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2);
|
|
671
|
+
factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2);
|
|
672
|
+
|
|
673
|
+
// total rows in q
|
|
674
|
+
const uint32_t neq0 = q->ne[0];
|
|
675
|
+
const uint32_t neq1 = q->ne[1];
|
|
676
|
+
const uint32_t neq2 = q->ne[2];
|
|
677
|
+
const uint32_t neq3 = q->ne[3];
|
|
529
678
|
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
|
533
|
-
size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
|
|
679
|
+
factx.qrows = neq1*neq2*neq3;
|
|
680
|
+
factx.qrows_per_thread = (factx.qrows + octx->n_threads - 1) / octx->n_threads;
|
|
534
681
|
|
|
535
|
-
size_t size_vkq_acc =
|
|
682
|
+
size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
|
|
536
683
|
|
|
537
684
|
octx->src0_spad.size_per_thread = size_q_block * 1;
|
|
538
|
-
octx->src1_spad.size_per_thread = size_k_block * 2;
|
|
539
|
-
octx->src2_spad.size_per_thread = size_v_block * 2;
|
|
540
|
-
octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
|
|
685
|
+
octx->src1_spad.size_per_thread = factx.size_k_block * 2;
|
|
686
|
+
octx->src2_spad.size_per_thread = factx.size_v_block * 2;
|
|
687
|
+
octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0;
|
|
541
688
|
octx->dst_spad.size_per_thread = size_vkq_acc;
|
|
542
689
|
|
|
543
690
|
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
|
|
@@ -559,7 +706,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
|
|
|
559
706
|
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
|
|
560
707
|
|
|
561
708
|
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
|
562
|
-
worker_pool_run_func(octx->ctx->worker_pool,
|
|
709
|
+
worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads);
|
|
563
710
|
}
|
|
564
711
|
|
|
565
712
|
return HTP_STATUS_OK;
|