whispercpp 1.3.5 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +99 -2
- data/ext/extconf.rb +1 -0
- data/ext/ruby_whisper.c +20 -4
- data/ext/ruby_whisper.h +30 -2
- data/ext/ruby_whisper_context.c +216 -124
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_params.c +0 -1
- data/ext/ruby_whisper_segment.c +0 -1
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +4 -1
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +8 -0
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/server/server.cpp +18 -4
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
- data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
- data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
- data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
- data/ext/sources/examples/talk-llama/llama-context.h +27 -28
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
- data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
- data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
- data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
- data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
- data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
- data/ext/sources/examples/talk-llama/llama-model.h +72 -19
- data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
- data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
- data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -22
- data/ext/sources/examples/talk-llama/llama.h +63 -30
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
- data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
- data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
- data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
- data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
- data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/models.h +181 -46
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
- data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
- data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
- data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
- data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
- data/ext/sources/ggml/CMakeLists.txt +9 -3
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +6 -1
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +56 -9
- data/ext/sources/ggml/src/CMakeLists.txt +3 -0
- data/ext/sources/ggml/src/ggml-alloc.c +4 -9
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
- data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
- data/ext/sources/ggml/src/ggml-impl.h +62 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
- data/ext/sources/ggml/src/ggml.c +167 -33
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/src/whisper.cpp +6 -28
- data/sig/whisper.rbs +43 -2
- data/test/test_context_params.rb +82 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_whisper.rb +20 -0
- data/whispercpp.gemspec +1 -1
- metadata +240 -28
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
|
@@ -3,14 +3,14 @@
|
|
|
3
3
|
#include "ggml-cpu.h"
|
|
4
4
|
#include "ggml-impl.h"
|
|
5
5
|
#include "binary-ops.h"
|
|
6
|
+
#include "simd-gemm.h"
|
|
6
7
|
#include "ggml.h"
|
|
7
8
|
#include "unary-ops.h"
|
|
8
9
|
#include "vec.h"
|
|
9
10
|
|
|
10
|
-
#include <cfloat>
|
|
11
11
|
#include <algorithm>
|
|
12
|
+
#include <cfloat>
|
|
12
13
|
#include <cmath>
|
|
13
|
-
#include <functional>
|
|
14
14
|
|
|
15
15
|
// ggml_compute_forward_dup
|
|
16
16
|
|
|
@@ -375,7 +375,7 @@ static void ggml_compute_forward_dup_bytes(
|
|
|
375
375
|
const size_t rs = ne00 * type_size;
|
|
376
376
|
|
|
377
377
|
if (nb00 == type_size) {
|
|
378
|
-
// src0 is
|
|
378
|
+
// src0 is contiguous on first dimension, copy by rows
|
|
379
379
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
380
380
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
381
381
|
id += rs * ir0;
|
|
@@ -670,6 +670,7 @@ void ggml_compute_forward_add(
|
|
|
670
670
|
case GGML_TYPE_Q5_1:
|
|
671
671
|
case GGML_TYPE_Q8_0:
|
|
672
672
|
case GGML_TYPE_MXFP4:
|
|
673
|
+
case GGML_TYPE_NVFP4:
|
|
673
674
|
case GGML_TYPE_Q2_K:
|
|
674
675
|
case GGML_TYPE_Q3_K:
|
|
675
676
|
case GGML_TYPE_Q4_K:
|
|
@@ -1119,6 +1120,7 @@ void ggml_compute_forward_add1(
|
|
|
1119
1120
|
case GGML_TYPE_Q8_0:
|
|
1120
1121
|
case GGML_TYPE_Q8_1:
|
|
1121
1122
|
case GGML_TYPE_MXFP4:
|
|
1123
|
+
case GGML_TYPE_NVFP4:
|
|
1122
1124
|
case GGML_TYPE_Q2_K:
|
|
1123
1125
|
case GGML_TYPE_Q3_K:
|
|
1124
1126
|
case GGML_TYPE_Q4_K:
|
|
@@ -1247,6 +1249,7 @@ void ggml_compute_forward_acc(
|
|
|
1247
1249
|
case GGML_TYPE_Q8_0:
|
|
1248
1250
|
case GGML_TYPE_Q8_1:
|
|
1249
1251
|
case GGML_TYPE_MXFP4:
|
|
1252
|
+
case GGML_TYPE_NVFP4:
|
|
1250
1253
|
case GGML_TYPE_Q2_K:
|
|
1251
1254
|
case GGML_TYPE_Q3_K:
|
|
1252
1255
|
case GGML_TYPE_Q4_K:
|
|
@@ -1795,7 +1798,7 @@ void ggml_compute_forward_repeat(
|
|
|
1795
1798
|
{
|
|
1796
1799
|
ggml_compute_forward_repeat_f32(params, dst);
|
|
1797
1800
|
} break;
|
|
1798
|
-
// TODO: templateify the
|
|
1801
|
+
// TODO: templateify the implementation and support for I64
|
|
1799
1802
|
// ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
|
|
1800
1803
|
//case GGML_TYPE_I64:
|
|
1801
1804
|
// {
|
|
@@ -2097,10 +2100,14 @@ static void ggml_compute_forward_gelu_f32(
|
|
|
2097
2100
|
|
|
2098
2101
|
const ggml_tensor * src0 = dst->src[0];
|
|
2099
2102
|
|
|
2100
|
-
assert(
|
|
2101
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2103
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2102
2104
|
assert(ggml_are_same_shape(src0, dst));
|
|
2103
2105
|
|
|
2106
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2107
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2108
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2109
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2110
|
+
|
|
2104
2111
|
const int ith = params->ith;
|
|
2105
2112
|
const int nth = params->nth;
|
|
2106
2113
|
|
|
@@ -2114,19 +2121,23 @@ static void ggml_compute_forward_gelu_f32(
|
|
|
2114
2121
|
const int ir0 = dr*ith;
|
|
2115
2122
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2116
2123
|
|
|
2117
|
-
for (int
|
|
2124
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2125
|
+
const int i3 = ir/(ne02*ne01);
|
|
2126
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2127
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2128
|
+
|
|
2118
2129
|
ggml_vec_gelu_f32(nc,
|
|
2119
|
-
(float *) ((char *) dst->data + i1*
|
|
2120
|
-
(float *) ((char *) src0->data + i1*
|
|
2130
|
+
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2131
|
+
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2121
2132
|
|
|
2122
2133
|
#ifndef NDEBUG
|
|
2123
2134
|
for (int k = 0; k < nc; k++) {
|
|
2124
|
-
const float x = ((float *) ((char *) dst->data + i1*(
|
|
2135
|
+
const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
|
|
2125
2136
|
GGML_UNUSED(x);
|
|
2126
2137
|
assert(!isnan(x));
|
|
2127
2138
|
assert(!isinf(x));
|
|
2128
2139
|
}
|
|
2129
|
-
#endif
|
|
2140
|
+
#endif // NDEBUG
|
|
2130
2141
|
}
|
|
2131
2142
|
}
|
|
2132
2143
|
|
|
@@ -2136,10 +2147,14 @@ static void ggml_compute_forward_gelu_f16(
|
|
|
2136
2147
|
|
|
2137
2148
|
const ggml_tensor * src0 = dst->src[0];
|
|
2138
2149
|
|
|
2139
|
-
assert(
|
|
2140
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2150
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2141
2151
|
assert(ggml_are_same_shape(src0, dst));
|
|
2142
2152
|
|
|
2153
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2154
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2155
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2156
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2157
|
+
|
|
2143
2158
|
const int ith = params->ith;
|
|
2144
2159
|
const int nth = params->nth;
|
|
2145
2160
|
|
|
@@ -2153,20 +2168,24 @@ static void ggml_compute_forward_gelu_f16(
|
|
|
2153
2168
|
const int ir0 = dr*ith;
|
|
2154
2169
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2155
2170
|
|
|
2156
|
-
for (int
|
|
2171
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2172
|
+
const int i3 = ir/(ne02*ne01);
|
|
2173
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2174
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2175
|
+
|
|
2157
2176
|
ggml_vec_gelu_f16(nc,
|
|
2158
|
-
(ggml_fp16_t *) ((char *) dst->data + i1*
|
|
2159
|
-
(ggml_fp16_t *) ((char *) src0->data + i1*
|
|
2177
|
+
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2178
|
+
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2160
2179
|
|
|
2161
2180
|
#ifndef NDEBUG
|
|
2162
2181
|
for (int k = 0; k < nc; k++) {
|
|
2163
|
-
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
2182
|
+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
|
|
2164
2183
|
const float v = GGML_CPU_FP16_TO_FP32(x);
|
|
2165
2184
|
GGML_UNUSED(v);
|
|
2166
2185
|
assert(!isnan(v));
|
|
2167
2186
|
assert(!isinf(v));
|
|
2168
2187
|
}
|
|
2169
|
-
#endif
|
|
2188
|
+
#endif // NDEBUG
|
|
2170
2189
|
}
|
|
2171
2190
|
}
|
|
2172
2191
|
|
|
@@ -2277,10 +2296,14 @@ static void ggml_compute_forward_gelu_erf_f32(
|
|
|
2277
2296
|
|
|
2278
2297
|
const ggml_tensor * src0 = dst->src[0];
|
|
2279
2298
|
|
|
2280
|
-
assert(
|
|
2281
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2299
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2282
2300
|
assert(ggml_are_same_shape(src0, dst));
|
|
2283
2301
|
|
|
2302
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2303
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2304
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2305
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2306
|
+
|
|
2284
2307
|
const int ith = params->ith;
|
|
2285
2308
|
const int nth = params->nth;
|
|
2286
2309
|
|
|
@@ -2294,19 +2317,23 @@ static void ggml_compute_forward_gelu_erf_f32(
|
|
|
2294
2317
|
const int ir0 = dr*ith;
|
|
2295
2318
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2296
2319
|
|
|
2297
|
-
for (int
|
|
2320
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2321
|
+
const int i3 = ir/(ne02*ne01);
|
|
2322
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2323
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2324
|
+
|
|
2298
2325
|
ggml_vec_gelu_erf_f32(nc,
|
|
2299
|
-
(float *) ((char *) dst->data + i1*
|
|
2300
|
-
(float *) ((char *) src0->data + i1*
|
|
2326
|
+
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2327
|
+
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2301
2328
|
|
|
2302
2329
|
#ifndef NDEBUG
|
|
2303
2330
|
for (int k = 0; k < nc; k++) {
|
|
2304
|
-
const float x = ((float *) ((char *) dst->data + i1*(
|
|
2331
|
+
const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
|
|
2305
2332
|
GGML_UNUSED(x);
|
|
2306
2333
|
assert(!isnan(x));
|
|
2307
2334
|
assert(!isinf(x));
|
|
2308
2335
|
}
|
|
2309
|
-
#endif
|
|
2336
|
+
#endif // NDEBUG
|
|
2310
2337
|
}
|
|
2311
2338
|
}
|
|
2312
2339
|
|
|
@@ -2316,10 +2343,14 @@ static void ggml_compute_forward_gelu_erf_f16(
|
|
|
2316
2343
|
|
|
2317
2344
|
const ggml_tensor * src0 = dst->src[0];
|
|
2318
2345
|
|
|
2319
|
-
assert(
|
|
2320
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2346
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2321
2347
|
assert(ggml_are_same_shape(src0, dst));
|
|
2322
2348
|
|
|
2349
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2350
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2351
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2352
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2353
|
+
|
|
2323
2354
|
const int ith = params->ith;
|
|
2324
2355
|
const int nth = params->nth;
|
|
2325
2356
|
|
|
@@ -2333,20 +2364,24 @@ static void ggml_compute_forward_gelu_erf_f16(
|
|
|
2333
2364
|
const int ir0 = dr*ith;
|
|
2334
2365
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2335
2366
|
|
|
2336
|
-
for (int
|
|
2367
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2368
|
+
const int i3 = ir/(ne02*ne01);
|
|
2369
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2370
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2371
|
+
|
|
2337
2372
|
ggml_vec_gelu_erf_f16(nc,
|
|
2338
|
-
(ggml_fp16_t *) ((char *) dst->data + i1*
|
|
2339
|
-
(ggml_fp16_t *) ((char *) src0->data + i1*
|
|
2373
|
+
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2374
|
+
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2340
2375
|
|
|
2341
2376
|
#ifndef NDEBUG
|
|
2342
2377
|
for (int k = 0; k < nc; k++) {
|
|
2343
|
-
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
2378
|
+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
|
|
2344
2379
|
const float v = GGML_CPU_FP16_TO_FP32(x);
|
|
2345
2380
|
GGML_UNUSED(v);
|
|
2346
2381
|
assert(!isnan(v));
|
|
2347
2382
|
assert(!isinf(v));
|
|
2348
2383
|
}
|
|
2349
|
-
#endif
|
|
2384
|
+
#endif // NDEBUG
|
|
2350
2385
|
}
|
|
2351
2386
|
}
|
|
2352
2387
|
|
|
@@ -2380,10 +2415,14 @@ static void ggml_compute_forward_gelu_quick_f32(
|
|
|
2380
2415
|
|
|
2381
2416
|
const ggml_tensor * src0 = dst->src[0];
|
|
2382
2417
|
|
|
2383
|
-
assert(
|
|
2384
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2418
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2385
2419
|
assert(ggml_are_same_shape(src0, dst));
|
|
2386
2420
|
|
|
2421
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2422
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2423
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2424
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2425
|
+
|
|
2387
2426
|
const int ith = params->ith;
|
|
2388
2427
|
const int nth = params->nth;
|
|
2389
2428
|
|
|
@@ -2397,19 +2436,23 @@ static void ggml_compute_forward_gelu_quick_f32(
|
|
|
2397
2436
|
const int ir0 = dr*ith;
|
|
2398
2437
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2399
2438
|
|
|
2400
|
-
for (int
|
|
2439
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2440
|
+
const int i3 = ir/(ne02*ne01);
|
|
2441
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2442
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2443
|
+
|
|
2401
2444
|
ggml_vec_gelu_quick_f32(nc,
|
|
2402
|
-
(float *) ((char *) dst->data + i1*
|
|
2403
|
-
(float *) ((char *) src0->data + i1*
|
|
2445
|
+
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2446
|
+
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2404
2447
|
|
|
2405
2448
|
#ifndef NDEBUG
|
|
2406
2449
|
for (int k = 0; k < nc; k++) {
|
|
2407
|
-
const float x = ((float *) ((char *) dst->data + i1*(
|
|
2450
|
+
const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
|
|
2408
2451
|
GGML_UNUSED(x);
|
|
2409
2452
|
assert(!isnan(x));
|
|
2410
2453
|
assert(!isinf(x));
|
|
2411
2454
|
}
|
|
2412
|
-
#endif
|
|
2455
|
+
#endif // NDEBUG
|
|
2413
2456
|
}
|
|
2414
2457
|
}
|
|
2415
2458
|
|
|
@@ -2419,10 +2462,14 @@ static void ggml_compute_forward_gelu_quick_f16(
|
|
|
2419
2462
|
|
|
2420
2463
|
const ggml_tensor * src0 = dst->src[0];
|
|
2421
2464
|
|
|
2422
|
-
assert(
|
|
2423
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2465
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2424
2466
|
assert(ggml_are_same_shape(src0, dst));
|
|
2425
2467
|
|
|
2468
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2469
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2470
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2471
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2472
|
+
|
|
2426
2473
|
const int ith = params->ith;
|
|
2427
2474
|
const int nth = params->nth;
|
|
2428
2475
|
|
|
@@ -2436,20 +2483,24 @@ static void ggml_compute_forward_gelu_quick_f16(
|
|
|
2436
2483
|
const int ir0 = dr*ith;
|
|
2437
2484
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2438
2485
|
|
|
2439
|
-
for (int
|
|
2486
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2487
|
+
const int i3 = ir/(ne02*ne01);
|
|
2488
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2489
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2490
|
+
|
|
2440
2491
|
ggml_vec_gelu_quick_f16(nc,
|
|
2441
|
-
(ggml_fp16_t *) ((char *) dst->data + i1*
|
|
2442
|
-
(ggml_fp16_t *) ((char *) src0->data + i1*
|
|
2492
|
+
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2493
|
+
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2443
2494
|
|
|
2444
2495
|
#ifndef NDEBUG
|
|
2445
2496
|
for (int k = 0; k < nc; k++) {
|
|
2446
|
-
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
2497
|
+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
|
|
2447
2498
|
const float v = GGML_CPU_FP16_TO_FP32(x);
|
|
2448
2499
|
GGML_UNUSED(v);
|
|
2449
2500
|
assert(!isnan(v));
|
|
2450
2501
|
assert(!isinf(v));
|
|
2451
2502
|
}
|
|
2452
|
-
#endif
|
|
2503
|
+
#endif // NDEBUG
|
|
2453
2504
|
}
|
|
2454
2505
|
}
|
|
2455
2506
|
|
|
@@ -2483,10 +2534,14 @@ static void ggml_compute_forward_silu_f32(
|
|
|
2483
2534
|
|
|
2484
2535
|
const ggml_tensor * src0 = dst->src[0];
|
|
2485
2536
|
|
|
2486
|
-
assert(
|
|
2487
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2537
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2488
2538
|
assert(ggml_are_same_shape(src0, dst));
|
|
2489
2539
|
|
|
2540
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2541
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2542
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2543
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2544
|
+
|
|
2490
2545
|
const int ith = params->ith;
|
|
2491
2546
|
const int nth = params->nth;
|
|
2492
2547
|
|
|
@@ -2500,19 +2555,23 @@ static void ggml_compute_forward_silu_f32(
|
|
|
2500
2555
|
const int ir0 = dr*ith;
|
|
2501
2556
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2502
2557
|
|
|
2503
|
-
for (int
|
|
2558
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2559
|
+
const int i3 = ir/(ne02*ne01);
|
|
2560
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2561
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2562
|
+
|
|
2504
2563
|
ggml_vec_silu_f32(nc,
|
|
2505
|
-
(float *) ((char *) dst->data + i1*
|
|
2506
|
-
(float *) ((char *) src0->data + i1*
|
|
2564
|
+
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2565
|
+
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2507
2566
|
|
|
2508
2567
|
#ifndef NDEBUG
|
|
2509
2568
|
for (int k = 0; k < nc; k++) {
|
|
2510
|
-
const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
|
|
2569
|
+
const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
|
|
2511
2570
|
GGML_UNUSED(x);
|
|
2512
2571
|
assert(!isnan(x));
|
|
2513
2572
|
assert(!isinf(x));
|
|
2514
2573
|
}
|
|
2515
|
-
#endif
|
|
2574
|
+
#endif // NDEBUG
|
|
2516
2575
|
}
|
|
2517
2576
|
}
|
|
2518
2577
|
|
|
@@ -2522,10 +2581,14 @@ static void ggml_compute_forward_silu_f16(
|
|
|
2522
2581
|
|
|
2523
2582
|
const ggml_tensor * src0 = dst->src[0];
|
|
2524
2583
|
|
|
2525
|
-
assert(
|
|
2526
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2584
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2527
2585
|
assert(ggml_are_same_shape(src0, dst));
|
|
2528
2586
|
|
|
2587
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2588
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2589
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2590
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2591
|
+
|
|
2529
2592
|
const int ith = params->ith;
|
|
2530
2593
|
const int nth = params->nth;
|
|
2531
2594
|
|
|
@@ -2539,20 +2602,24 @@ static void ggml_compute_forward_silu_f16(
|
|
|
2539
2602
|
const int ir0 = dr*ith;
|
|
2540
2603
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2541
2604
|
|
|
2542
|
-
for (int
|
|
2605
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2606
|
+
const int i3 = ir/(ne02*ne01);
|
|
2607
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2608
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2609
|
+
|
|
2543
2610
|
ggml_vec_silu_f16(nc,
|
|
2544
|
-
(ggml_fp16_t *) ((char *) dst->data + i1*
|
|
2545
|
-
(ggml_fp16_t *) ((char *) src0->data + i1*
|
|
2611
|
+
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2612
|
+
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2546
2613
|
|
|
2547
2614
|
#ifndef NDEBUG
|
|
2548
2615
|
for (int k = 0; k < nc; k++) {
|
|
2549
|
-
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
|
|
2616
|
+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
|
|
2550
2617
|
const float v = GGML_CPU_FP16_TO_FP32(x);
|
|
2551
2618
|
GGML_UNUSED(v);
|
|
2552
2619
|
assert(!isnan(v));
|
|
2553
2620
|
assert(!isinf(v));
|
|
2554
2621
|
}
|
|
2555
|
-
#endif
|
|
2622
|
+
#endif // NDEBUG
|
|
2556
2623
|
}
|
|
2557
2624
|
}
|
|
2558
2625
|
|
|
@@ -2702,7 +2769,7 @@ static void ggml_compute_forward_silu_back_f32(
|
|
|
2702
2769
|
assert(!isnan(x));
|
|
2703
2770
|
assert(!isinf(x));
|
|
2704
2771
|
}
|
|
2705
|
-
#endif
|
|
2772
|
+
#endif // NDEBUG
|
|
2706
2773
|
}
|
|
2707
2774
|
}
|
|
2708
2775
|
|
|
@@ -2738,7 +2805,7 @@ static void ggml_compute_forward_silu_back_f16(
|
|
|
2738
2805
|
(ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
|
|
2739
2806
|
(ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
|
|
2740
2807
|
|
|
2741
|
-
|
|
2808
|
+
#ifndef NDEBUG
|
|
2742
2809
|
for (int k = 0; k < nc; k++) {
|
|
2743
2810
|
const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
2744
2811
|
const float v = GGML_CPU_FP16_TO_FP32(x);
|
|
@@ -2746,7 +2813,7 @@ static void ggml_compute_forward_silu_back_f16(
|
|
|
2746
2813
|
assert(!isnan(v));
|
|
2747
2814
|
assert(!isinf(v));
|
|
2748
2815
|
}
|
|
2749
|
-
|
|
2816
|
+
#endif // NDEBUG
|
|
2750
2817
|
}
|
|
2751
2818
|
}
|
|
2752
2819
|
|
|
@@ -2829,7 +2896,7 @@ static void ggml_compute_forward_reglu_f32(
|
|
|
2829
2896
|
assert(!isnan(x));
|
|
2830
2897
|
assert(!isinf(x));
|
|
2831
2898
|
}
|
|
2832
|
-
#endif
|
|
2899
|
+
#endif // NDEBUG
|
|
2833
2900
|
}
|
|
2834
2901
|
}
|
|
2835
2902
|
|
|
@@ -2889,7 +2956,7 @@ static void ggml_compute_forward_reglu_f16(
|
|
|
2889
2956
|
assert(!isnan(v));
|
|
2890
2957
|
assert(!isinf(v));
|
|
2891
2958
|
}
|
|
2892
|
-
#endif
|
|
2959
|
+
#endif // NDEBUG
|
|
2893
2960
|
}
|
|
2894
2961
|
}
|
|
2895
2962
|
|
|
@@ -2972,7 +3039,7 @@ static void ggml_compute_forward_geglu_f32(
|
|
|
2972
3039
|
assert(!isnan(x));
|
|
2973
3040
|
assert(!isinf(x));
|
|
2974
3041
|
}
|
|
2975
|
-
#endif
|
|
3042
|
+
#endif // NDEBUG
|
|
2976
3043
|
}
|
|
2977
3044
|
}
|
|
2978
3045
|
|
|
@@ -3032,7 +3099,7 @@ static void ggml_compute_forward_geglu_f16(
|
|
|
3032
3099
|
assert(!isnan(v));
|
|
3033
3100
|
assert(!isinf(v));
|
|
3034
3101
|
}
|
|
3035
|
-
#endif
|
|
3102
|
+
#endif // NDEBUG
|
|
3036
3103
|
}
|
|
3037
3104
|
}
|
|
3038
3105
|
|
|
@@ -3115,7 +3182,7 @@ static void ggml_compute_forward_swiglu_f32(
|
|
|
3115
3182
|
assert(!isnan(x));
|
|
3116
3183
|
assert(!isinf(x));
|
|
3117
3184
|
}
|
|
3118
|
-
#endif
|
|
3185
|
+
#endif // NDEBUG
|
|
3119
3186
|
}
|
|
3120
3187
|
}
|
|
3121
3188
|
|
|
@@ -3175,7 +3242,7 @@ static void ggml_compute_forward_swiglu_f16(
|
|
|
3175
3242
|
assert(!isnan(v));
|
|
3176
3243
|
assert(!isinf(v));
|
|
3177
3244
|
}
|
|
3178
|
-
#endif
|
|
3245
|
+
#endif // NDEBUG
|
|
3179
3246
|
}
|
|
3180
3247
|
}
|
|
3181
3248
|
|
|
@@ -3266,7 +3333,7 @@ static void ggml_compute_forward_swiglu_oai_f32(
|
|
|
3266
3333
|
assert(!isnan(x));
|
|
3267
3334
|
assert(!isinf(x));
|
|
3268
3335
|
}
|
|
3269
|
-
#endif
|
|
3336
|
+
#endif // NDEBUG
|
|
3270
3337
|
}
|
|
3271
3338
|
}
|
|
3272
3339
|
|
|
@@ -3345,7 +3412,7 @@ static void ggml_compute_forward_geglu_erf_f32(
|
|
|
3345
3412
|
assert(!isnan(x));
|
|
3346
3413
|
assert(!isinf(x));
|
|
3347
3414
|
}
|
|
3348
|
-
#endif
|
|
3415
|
+
#endif // NDEBUG
|
|
3349
3416
|
}
|
|
3350
3417
|
}
|
|
3351
3418
|
|
|
@@ -3405,7 +3472,7 @@ static void ggml_compute_forward_geglu_erf_f16(
|
|
|
3405
3472
|
assert(!isnan(v));
|
|
3406
3473
|
assert(!isinf(v));
|
|
3407
3474
|
}
|
|
3408
|
-
#endif
|
|
3475
|
+
#endif // NDEBUG
|
|
3409
3476
|
}
|
|
3410
3477
|
}
|
|
3411
3478
|
|
|
@@ -3488,7 +3555,7 @@ static void ggml_compute_forward_geglu_quick_f32(
|
|
|
3488
3555
|
assert(!isnan(x));
|
|
3489
3556
|
assert(!isinf(x));
|
|
3490
3557
|
}
|
|
3491
|
-
#endif
|
|
3558
|
+
#endif // NDEBUG
|
|
3492
3559
|
}
|
|
3493
3560
|
}
|
|
3494
3561
|
|
|
@@ -3548,7 +3615,7 @@ static void ggml_compute_forward_geglu_quick_f16(
|
|
|
3548
3615
|
assert(!isnan(v));
|
|
3549
3616
|
assert(!isinf(v));
|
|
3550
3617
|
}
|
|
3551
|
-
#endif
|
|
3618
|
+
#endif // NDEBUG
|
|
3552
3619
|
}
|
|
3553
3620
|
}
|
|
3554
3621
|
|
|
@@ -4270,6 +4337,7 @@ void ggml_compute_forward_out_prod(
|
|
|
4270
4337
|
case GGML_TYPE_Q5_1:
|
|
4271
4338
|
case GGML_TYPE_Q8_0:
|
|
4272
4339
|
case GGML_TYPE_MXFP4:
|
|
4340
|
+
case GGML_TYPE_NVFP4:
|
|
4273
4341
|
case GGML_TYPE_Q2_K:
|
|
4274
4342
|
case GGML_TYPE_Q3_K:
|
|
4275
4343
|
case GGML_TYPE_Q4_K:
|
|
@@ -4545,6 +4613,7 @@ void ggml_compute_forward_set(
|
|
|
4545
4613
|
case GGML_TYPE_Q8_0:
|
|
4546
4614
|
case GGML_TYPE_Q8_1:
|
|
4547
4615
|
case GGML_TYPE_MXFP4:
|
|
4616
|
+
case GGML_TYPE_NVFP4:
|
|
4548
4617
|
case GGML_TYPE_Q2_K:
|
|
4549
4618
|
case GGML_TYPE_Q3_K:
|
|
4550
4619
|
case GGML_TYPE_Q4_K:
|
|
@@ -4767,6 +4836,7 @@ void ggml_compute_forward_get_rows(
|
|
|
4767
4836
|
case GGML_TYPE_Q8_0:
|
|
4768
4837
|
case GGML_TYPE_Q8_1:
|
|
4769
4838
|
case GGML_TYPE_MXFP4:
|
|
4839
|
+
case GGML_TYPE_NVFP4:
|
|
4770
4840
|
case GGML_TYPE_Q2_K:
|
|
4771
4841
|
case GGML_TYPE_Q3_K:
|
|
4772
4842
|
case GGML_TYPE_Q4_K:
|
|
@@ -5239,7 +5309,7 @@ static void ggml_compute_forward_soft_max_f32(
|
|
|
5239
5309
|
//printf("p[%d] = %f\n", i, p[i]);
|
|
5240
5310
|
assert(!isnan(wp[i]));
|
|
5241
5311
|
}
|
|
5242
|
-
#endif
|
|
5312
|
+
#endif // NDEBUG
|
|
5243
5313
|
|
|
5244
5314
|
float max = -INFINITY;
|
|
5245
5315
|
ggml_vec_max_f32(ne00, &max, wp);
|
|
@@ -5264,7 +5334,7 @@ static void ggml_compute_forward_soft_max_f32(
|
|
|
5264
5334
|
assert(!isnan(dp[i]));
|
|
5265
5335
|
assert(!isinf(dp[i]));
|
|
5266
5336
|
}
|
|
5267
|
-
#endif
|
|
5337
|
+
#endif // NDEBUG
|
|
5268
5338
|
}
|
|
5269
5339
|
}
|
|
5270
5340
|
}
|
|
@@ -5338,7 +5408,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32(
|
|
|
5338
5408
|
assert(!isnan(dy[i]));
|
|
5339
5409
|
assert(!isnan(y[i]));
|
|
5340
5410
|
}
|
|
5341
|
-
#endif
|
|
5411
|
+
#endif // NDEBUG
|
|
5342
5412
|
// Jii = yi - yi*yi
|
|
5343
5413
|
// Jij = -yi*yj
|
|
5344
5414
|
// J = diag(y)-y.T*y
|
|
@@ -5371,7 +5441,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32(
|
|
|
5371
5441
|
assert(!isnan(dx[i]));
|
|
5372
5442
|
assert(!isinf(dx[i]));
|
|
5373
5443
|
}
|
|
5374
|
-
#endif
|
|
5444
|
+
#endif // NDEBUG
|
|
5375
5445
|
}
|
|
5376
5446
|
}
|
|
5377
5447
|
|
|
@@ -5491,6 +5561,7 @@ void ggml_compute_forward_clamp(
|
|
|
5491
5561
|
case GGML_TYPE_Q8_0:
|
|
5492
5562
|
case GGML_TYPE_Q8_1:
|
|
5493
5563
|
case GGML_TYPE_MXFP4:
|
|
5564
|
+
case GGML_TYPE_NVFP4:
|
|
5494
5565
|
case GGML_TYPE_Q2_K:
|
|
5495
5566
|
case GGML_TYPE_Q3_K:
|
|
5496
5567
|
case GGML_TYPE_Q4_K:
|
|
@@ -5739,28 +5810,33 @@ static void ggml_compute_forward_rope_flt(
|
|
|
5739
5810
|
|
|
5740
5811
|
const int32_t * pos = (const int32_t *) src1->data;
|
|
5741
5812
|
|
|
5813
|
+
int64_t last_i2 = -1;
|
|
5814
|
+
|
|
5742
5815
|
for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
|
|
5743
5816
|
for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
|
5744
|
-
|
|
5745
|
-
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
|
5746
|
-
if (!mrope_used) {
|
|
5747
|
-
const int64_t p = pos[i2];
|
|
5748
|
-
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
5749
|
-
}
|
|
5750
|
-
else {
|
|
5751
|
-
const int64_t p_t = pos[i2];
|
|
5752
|
-
const int64_t p_h = pos[i2 + ne2];
|
|
5753
|
-
const int64_t p_w = pos[i2 + ne2 * 2];
|
|
5754
|
-
const int64_t p_e = pos[i2 + ne2 * 3];
|
|
5755
|
-
ggml_mrope_cache_init(
|
|
5756
|
-
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
|
|
5757
|
-
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
5758
|
-
}
|
|
5759
|
-
|
|
5760
5817
|
for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
|
|
5761
|
-
if (ir++ < ir0) continue;
|
|
5818
|
+
if (ir++ < ir0) continue; // skip rows mapped to other threads
|
|
5762
5819
|
if (ir > ir1) break;
|
|
5763
5820
|
|
|
5821
|
+
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
|
5822
|
+
if (last_i2 != i2) {
|
|
5823
|
+
if (!mrope_used) {
|
|
5824
|
+
const int64_t p = pos[i2];
|
|
5825
|
+
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
5826
|
+
}
|
|
5827
|
+
else {
|
|
5828
|
+
const int64_t p_t = pos[i2];
|
|
5829
|
+
const int64_t p_h = pos[i2 + ne2];
|
|
5830
|
+
const int64_t p_w = pos[i2 + ne2 * 2];
|
|
5831
|
+
const int64_t p_e = pos[i2 + ne2 * 3];
|
|
5832
|
+
ggml_mrope_cache_init(
|
|
5833
|
+
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
|
|
5834
|
+
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
5835
|
+
}
|
|
5836
|
+
|
|
5837
|
+
last_i2 = i2;
|
|
5838
|
+
}
|
|
5839
|
+
|
|
5764
5840
|
T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
|
|
5765
5841
|
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
|
|
5766
5842
|
|
|
@@ -6129,7 +6205,7 @@ static void ggml_compute_forward_im2col_f16(
|
|
|
6129
6205
|
const ggml_tensor * src1 = dst->src[1];
|
|
6130
6206
|
|
|
6131
6207
|
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
|
6132
|
-
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
6208
|
+
GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
|
|
6133
6209
|
GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
|
6134
6210
|
|
|
6135
6211
|
GGML_TENSOR_BINARY_OP_LOCALS;
|
|
@@ -6160,7 +6236,7 @@ static void ggml_compute_forward_im2col_f16(
|
|
|
6160
6236
|
int ofs1 = is_2D ? nb12 : nb11;
|
|
6161
6237
|
|
|
6162
6238
|
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
|
6163
|
-
GGML_ASSERT(nb10 ==
|
|
6239
|
+
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
|
|
6164
6240
|
|
|
6165
6241
|
// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
|
|
6166
6242
|
{
|
|
@@ -6173,7 +6249,12 @@ static void ggml_compute_forward_im2col_f16(
|
|
|
6173
6249
|
|
|
6174
6250
|
// micro kernel
|
|
6175
6251
|
ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
|
|
6176
|
-
const float * const
|
|
6252
|
+
const float * const src_data_f32 = src1->type == GGML_TYPE_F32
|
|
6253
|
+
? (const float *)((const char *) src1->data + in*ofs0 + iic*ofs1)
|
|
6254
|
+
: nullptr; // [IH, IW]
|
|
6255
|
+
const ggml_fp16_t * const src_data_f16 = src1->type == GGML_TYPE_F16
|
|
6256
|
+
? (const ggml_fp16_t *)((const char *) src1->data + in*ofs0 + iic*ofs1)
|
|
6257
|
+
: nullptr; // [IH, IW]
|
|
6177
6258
|
|
|
6178
6259
|
for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
|
|
6179
6260
|
for (int64_t ikw = 0; ikw < KW; ikw++) {
|
|
@@ -6183,7 +6264,11 @@ static void ggml_compute_forward_im2col_f16(
|
|
|
6183
6264
|
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
|
6184
6265
|
dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
|
|
6185
6266
|
} else {
|
|
6186
|
-
|
|
6267
|
+
if (src_data_f32 != nullptr) {
|
|
6268
|
+
dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data_f32[iih*IW + iiw]);
|
|
6269
|
+
} else {
|
|
6270
|
+
dst_data[iic*(KH*KW) + ikh*KW + ikw] = src_data_f16[iih*IW + iiw];
|
|
6271
|
+
}
|
|
6187
6272
|
}
|
|
6188
6273
|
}
|
|
6189
6274
|
}
|
|
@@ -7110,12 +7195,13 @@ void ggml_compute_forward_conv_2d_dw(
|
|
|
7110
7195
|
}
|
|
7111
7196
|
}
|
|
7112
7197
|
|
|
7113
|
-
//
|
|
7114
|
-
|
|
7115
|
-
static void ggml_compute_forward_pool_1d_sk_p0(
|
|
7198
|
+
// ggml_compute_forward_pool_1d_ksp
|
|
7199
|
+
static void ggml_compute_forward_pool_1d_ksp(
|
|
7116
7200
|
const ggml_compute_params * params,
|
|
7117
7201
|
const ggml_op_pool op,
|
|
7118
7202
|
const int k,
|
|
7203
|
+
const int s,
|
|
7204
|
+
const int p,
|
|
7119
7205
|
ggml_tensor * dst) {
|
|
7120
7206
|
|
|
7121
7207
|
const ggml_tensor * src = dst->src[0];
|
|
@@ -7126,39 +7212,56 @@ static void ggml_compute_forward_pool_1d_sk_p0(
|
|
|
7126
7212
|
return;
|
|
7127
7213
|
}
|
|
7128
7214
|
|
|
7129
|
-
const
|
|
7130
|
-
const
|
|
7131
|
-
float * drow = (float *)dst->data;
|
|
7215
|
+
const int64_t IW = src->ne[0];
|
|
7216
|
+
const int64_t OW = dst->ne[0];
|
|
7132
7217
|
|
|
7133
|
-
const int64_t
|
|
7218
|
+
const int64_t nr = ggml_nrows(src);
|
|
7134
7219
|
|
|
7135
|
-
|
|
7136
|
-
const
|
|
7137
|
-
|
|
7138
|
-
|
|
7220
|
+
for (int64_t ir = 0; ir < nr; ++ir) {
|
|
7221
|
+
const char * srow_bytes = (const char *) src->data + ir * src->nb[1];
|
|
7222
|
+
float * drow = (float *) (( char *) dst->data + ir * dst->nb[1]);
|
|
7223
|
+
|
|
7224
|
+
for (int64_t ow = 0; ow < OW; ++ow) {
|
|
7225
|
+
float res = 0;
|
|
7139
7226
|
switch (op) {
|
|
7140
|
-
case GGML_OP_POOL_AVG:
|
|
7141
|
-
case GGML_OP_POOL_MAX:
|
|
7227
|
+
case GGML_OP_POOL_AVG: res = 0.0f; break;
|
|
7228
|
+
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
|
|
7142
7229
|
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
|
7143
7230
|
}
|
|
7231
|
+
|
|
7232
|
+
int count = 0;
|
|
7233
|
+
const int base = (int) ow * s - p;
|
|
7234
|
+
|
|
7144
7235
|
for (int ki = 0; ki < k; ++ki) {
|
|
7145
|
-
const
|
|
7236
|
+
const int j = base + ki;
|
|
7237
|
+
if (j < 0 || j >= (int) IW) {
|
|
7238
|
+
continue;
|
|
7239
|
+
}
|
|
7240
|
+
|
|
7241
|
+
float v;
|
|
7242
|
+
if (src->type == GGML_TYPE_F32) {
|
|
7243
|
+
v = ((const float *) srow_bytes)[j];
|
|
7244
|
+
} else {
|
|
7245
|
+
v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]);
|
|
7246
|
+
}
|
|
7247
|
+
|
|
7146
7248
|
switch (op) {
|
|
7147
|
-
case GGML_OP_POOL_AVG:
|
|
7148
|
-
case GGML_OP_POOL_MAX:
|
|
7149
|
-
case GGML_OP_POOL_COUNT:
|
|
7249
|
+
case GGML_OP_POOL_AVG: res += v; break;
|
|
7250
|
+
case GGML_OP_POOL_MAX: res = std::max(v, res); break;
|
|
7251
|
+
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
|
7150
7252
|
}
|
|
7151
|
-
|
|
7253
|
+
|
|
7254
|
+
++count;
|
|
7152
7255
|
}
|
|
7256
|
+
|
|
7153
7257
|
switch (op) {
|
|
7154
|
-
case GGML_OP_POOL_AVG:
|
|
7155
|
-
case GGML_OP_POOL_MAX:
|
|
7258
|
+
case GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break;
|
|
7259
|
+
case GGML_OP_POOL_MAX: break;
|
|
7156
7260
|
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
|
7157
7261
|
}
|
|
7158
|
-
}
|
|
7159
7262
|
|
|
7160
|
-
|
|
7161
|
-
|
|
7263
|
+
drow[ow] = res;
|
|
7264
|
+
}
|
|
7162
7265
|
}
|
|
7163
7266
|
}
|
|
7164
7267
|
|
|
@@ -7173,10 +7276,8 @@ void ggml_compute_forward_pool_1d(
|
|
|
7173
7276
|
const int k0 = opts[1];
|
|
7174
7277
|
const int s0 = opts[2];
|
|
7175
7278
|
const int p0 = opts[3];
|
|
7176
|
-
GGML_ASSERT(p0 == 0); // padding not supported
|
|
7177
|
-
GGML_ASSERT(k0 == s0); // only s = k supported
|
|
7178
7279
|
|
|
7179
|
-
|
|
7280
|
+
ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst);
|
|
7180
7281
|
}
|
|
7181
7282
|
|
|
7182
7283
|
// ggml_compute_forward_pool_2d
|
|
@@ -7194,6 +7295,7 @@ void ggml_compute_forward_pool_2d(
|
|
|
7194
7295
|
}
|
|
7195
7296
|
|
|
7196
7297
|
const int32_t * opts = (const int32_t *)dst->op_params;
|
|
7298
|
+
|
|
7197
7299
|
ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
|
|
7198
7300
|
const int k0 = opts[1];
|
|
7199
7301
|
const int k1 = opts[2];
|
|
@@ -7217,11 +7319,13 @@ void ggml_compute_forward_pool_2d(
|
|
|
7217
7319
|
while (cdata < data_end) {
|
|
7218
7320
|
for (int oy = 0; oy < py; ++oy) {
|
|
7219
7321
|
float * const drow = dplane + oy * px;
|
|
7322
|
+
float * const out = drow;
|
|
7323
|
+
|
|
7220
7324
|
for (int ox = 0; ox < px; ++ox) {
|
|
7221
|
-
float
|
|
7325
|
+
float res = 0;
|
|
7222
7326
|
switch (op) {
|
|
7223
|
-
case GGML_OP_POOL_AVG:
|
|
7224
|
-
case GGML_OP_POOL_MAX:
|
|
7327
|
+
case GGML_OP_POOL_AVG: res = 0; break;
|
|
7328
|
+
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
|
|
7225
7329
|
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
|
7226
7330
|
}
|
|
7227
7331
|
|
|
@@ -7229,24 +7333,32 @@ void ggml_compute_forward_pool_2d(
|
|
|
7229
7333
|
const int iy = offset1 + oy * s1;
|
|
7230
7334
|
|
|
7231
7335
|
for (int ky = 0; ky < k1; ++ky) {
|
|
7232
|
-
if (iy + ky < 0 || iy + ky >= src->ne[1])
|
|
7336
|
+
if (iy + ky < 0 || iy + ky >= src->ne[1]) {
|
|
7337
|
+
continue;
|
|
7338
|
+
}
|
|
7339
|
+
|
|
7233
7340
|
const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
|
|
7234
7341
|
for (int kx = 0; kx < k0; ++kx) {
|
|
7235
7342
|
int j = ix + kx;
|
|
7236
|
-
if (j < 0 || j >= src->ne[0])
|
|
7343
|
+
if (j < 0 || j >= src->ne[0]) {
|
|
7344
|
+
continue;
|
|
7345
|
+
}
|
|
7346
|
+
|
|
7237
7347
|
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
|
|
7238
7348
|
switch (op) {
|
|
7239
|
-
case GGML_OP_POOL_AVG:
|
|
7240
|
-
case GGML_OP_POOL_MAX:
|
|
7349
|
+
case GGML_OP_POOL_AVG: res += srow_j; break;
|
|
7350
|
+
case GGML_OP_POOL_MAX: res = std::max(srow_j, res); break;
|
|
7241
7351
|
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
|
7242
7352
|
}
|
|
7243
7353
|
}
|
|
7244
7354
|
}
|
|
7245
7355
|
switch (op) {
|
|
7246
|
-
case GGML_OP_POOL_AVG:
|
|
7247
|
-
case GGML_OP_POOL_MAX:
|
|
7356
|
+
case GGML_OP_POOL_AVG: res /= ka; break;
|
|
7357
|
+
case GGML_OP_POOL_MAX: break;
|
|
7248
7358
|
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
|
7249
7359
|
}
|
|
7360
|
+
|
|
7361
|
+
out[ox] = res;
|
|
7250
7362
|
}
|
|
7251
7363
|
}
|
|
7252
7364
|
|
|
@@ -7603,8 +7715,7 @@ static void ggml_compute_forward_pad_f32(
|
|
|
7603
7715
|
|
|
7604
7716
|
const ggml_tensor * src0 = dst->src[0];
|
|
7605
7717
|
|
|
7606
|
-
|
|
7607
|
-
GGML_ASSERT( dst->nb[0] == sizeof(float));
|
|
7718
|
+
assert(dst->nb[0] == sizeof(float));
|
|
7608
7719
|
|
|
7609
7720
|
const int ith = params->ith;
|
|
7610
7721
|
const int nth = params->nth;
|
|
@@ -8016,12 +8127,14 @@ void ggml_compute_forward_top_k(
|
|
|
8016
8127
|
}
|
|
8017
8128
|
}
|
|
8018
8129
|
|
|
8019
|
-
// ggml_compute_forward_flash_attn_ext
|
|
8020
|
-
|
|
8021
8130
|
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|
8022
8131
|
const ggml_compute_params * params,
|
|
8023
8132
|
ggml_tensor * dst,
|
|
8024
|
-
int ir0, int ir1
|
|
8133
|
+
int ir0, int ir1,
|
|
8134
|
+
int64_t ic_start, int64_t ic_end,
|
|
8135
|
+
float * partials, int64_t partial_stride) {
|
|
8136
|
+
|
|
8137
|
+
const bool write_partials = (partials != nullptr);
|
|
8025
8138
|
const ggml_tensor * q = dst->src[0];
|
|
8026
8139
|
const ggml_tensor * k = dst->src[1];
|
|
8027
8140
|
const ggml_tensor * v = dst->src[2];
|
|
@@ -8098,7 +8211,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|
|
8098
8211
|
|
|
8099
8212
|
int ith = params->ith;
|
|
8100
8213
|
|
|
8101
|
-
// loop over n_batch and n_head
|
|
8102
8214
|
for (int ir = ir0; ir < ir1; ++ir) {
|
|
8103
8215
|
// q indices
|
|
8104
8216
|
const int iq3 = ir/(neq2*neq1);
|
|
@@ -8138,7 +8250,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|
|
8138
8250
|
// online softmax / attention
|
|
8139
8251
|
// loop over n_kv and n_head_kv
|
|
8140
8252
|
// ref: https://arxiv.org/pdf/2112.05682.pdf
|
|
8141
|
-
|
|
8253
|
+
|
|
8254
|
+
for (int64_t ic = ic_start; ic < ic_end; ++ic) {
|
|
8142
8255
|
const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
|
|
8143
8256
|
if (mv == -INFINITY) {
|
|
8144
8257
|
continue;
|
|
@@ -8211,8 +8324,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|
|
8211
8324
|
}
|
|
8212
8325
|
}
|
|
8213
8326
|
|
|
8214
|
-
// sinks
|
|
8215
|
-
if (sinks) {
|
|
8327
|
+
// sinks - apply only on the first kv-chunk
|
|
8328
|
+
if (sinks && ic_start == 0) {
|
|
8216
8329
|
const float s = ((float *)((char *) sinks->data))[h];
|
|
8217
8330
|
|
|
8218
8331
|
float ms = 1.0f;
|
|
@@ -8220,6 +8333,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|
|
8220
8333
|
|
|
8221
8334
|
if (s > M) {
|
|
8222
8335
|
ms = expf(M - s);
|
|
8336
|
+
M = s;
|
|
8223
8337
|
ggml_vec_scale_f32(DV, VKQ32, ms);
|
|
8224
8338
|
} else {
|
|
8225
8339
|
vs = expf(s - M);
|
|
@@ -8228,20 +8342,386 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|
|
8228
8342
|
S = S*ms + vs;
|
|
8229
8343
|
}
|
|
8230
8344
|
|
|
8231
|
-
|
|
8232
|
-
|
|
8233
|
-
|
|
8345
|
+
if (write_partials) {
|
|
8346
|
+
// Write M, S, VKQ to partials for later reduction
|
|
8347
|
+
// partials layout: [M, S, VKQ[DV]] per query head
|
|
8348
|
+
float * partial = partials + ir * partial_stride;
|
|
8349
|
+
partial[0] = M;
|
|
8350
|
+
partial[1] = S;
|
|
8351
|
+
memcpy(partial + 2, VKQ32, DV * sizeof(float));
|
|
8352
|
+
} else {
|
|
8353
|
+
// V /= S
|
|
8354
|
+
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
|
|
8355
|
+
ggml_vec_scale_f32(DV, VKQ32, S_inv);
|
|
8356
|
+
|
|
8357
|
+
// dst indices
|
|
8358
|
+
const int i1 = iq1;
|
|
8359
|
+
const int i2 = iq2;
|
|
8360
|
+
const int i3 = iq3;
|
|
8234
8361
|
|
|
8235
|
-
|
|
8236
|
-
|
|
8237
|
-
|
|
8238
|
-
|
|
8362
|
+
// permute(0, 2, 1, 3)
|
|
8363
|
+
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
|
|
8364
|
+
}
|
|
8365
|
+
}
|
|
8366
|
+
}
|
|
8367
|
+
|
|
8368
|
+
static void ggml_compute_forward_flash_attn_ext_tiled(
|
|
8369
|
+
const ggml_compute_params * params,
|
|
8370
|
+
ggml_tensor * dst,
|
|
8371
|
+
int ir0, int ir1) {
|
|
8372
|
+
const ggml_tensor * q = dst->src[0];
|
|
8373
|
+
const ggml_tensor * k = dst->src[1];
|
|
8374
|
+
const ggml_tensor * v = dst->src[2];
|
|
8375
|
+
const ggml_tensor * mask = dst->src[3];
|
|
8376
|
+
const ggml_tensor * sinks = dst->src[4];
|
|
8377
|
+
|
|
8378
|
+
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
|
8379
|
+
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
|
8380
|
+
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
|
8381
|
+
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
|
8382
|
+
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
|
8383
|
+
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
|
8384
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
8385
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
8386
|
+
|
|
8387
|
+
const int64_t DK = nek0;
|
|
8388
|
+
const int64_t DV = nev0;
|
|
8389
|
+
const int64_t N = neq1;
|
|
8390
|
+
|
|
8391
|
+
GGML_ASSERT(ne0 == DV);
|
|
8392
|
+
GGML_ASSERT(ne2 == N);
|
|
8393
|
+
|
|
8394
|
+
// input tensor rows must be contiguous
|
|
8395
|
+
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
|
|
8396
|
+
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
|
|
8397
|
+
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
|
|
8398
|
+
|
|
8399
|
+
GGML_ASSERT(neq0 == DK);
|
|
8400
|
+
GGML_ASSERT(nek0 == DK);
|
|
8401
|
+
GGML_ASSERT(nev0 == DV);
|
|
8402
|
+
|
|
8403
|
+
GGML_ASSERT(neq1 == N);
|
|
8404
|
+
|
|
8405
|
+
// dst cannot be transposed or permuted
|
|
8406
|
+
GGML_ASSERT(nb0 == sizeof(float));
|
|
8407
|
+
GGML_ASSERT(nb0 <= nb1);
|
|
8408
|
+
GGML_ASSERT(nb1 <= nb2);
|
|
8409
|
+
GGML_ASSERT(nb2 <= nb3);
|
|
8410
|
+
|
|
8411
|
+
GGML_ASSERT(k->type == v->type);
|
|
8412
|
+
const ggml_type kv_type = k->type;
|
|
8413
|
+
|
|
8414
|
+
|
|
8415
|
+
// broadcast factors
|
|
8416
|
+
const int64_t rk2 = neq2/nek2;
|
|
8417
|
+
const int64_t rk3 = neq3/nek3;
|
|
8418
|
+
|
|
8419
|
+
const int64_t rv2 = neq2/nev2;
|
|
8420
|
+
const int64_t rv3 = neq3/nev3;
|
|
8421
|
+
|
|
8422
|
+
float scale = 1.0f;
|
|
8423
|
+
float max_bias = 0.0f;
|
|
8424
|
+
float logit_softcap = 0.0f;
|
|
8425
|
+
|
|
8426
|
+
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
|
8427
|
+
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
|
8428
|
+
memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
|
|
8429
|
+
|
|
8430
|
+
if (logit_softcap != 0) {
|
|
8431
|
+
scale /= logit_softcap;
|
|
8432
|
+
}
|
|
8433
|
+
|
|
8434
|
+
const uint32_t n_head = neq2;
|
|
8435
|
+
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
|
8436
|
+
|
|
8437
|
+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
8438
|
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
8439
|
+
|
|
8440
|
+
int ith = params->ith;
|
|
8441
|
+
|
|
8442
|
+
static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
|
|
8443
|
+
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
|
|
8444
|
+
|
|
8445
|
+
int ir = ir0;
|
|
8446
|
+
while (ir < ir1) {
|
|
8447
|
+
// q indices for the start of this tile
|
|
8448
|
+
const int iq3 = ir/(neq2*neq1);
|
|
8449
|
+
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
|
|
8450
|
+
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
|
|
8451
|
+
|
|
8452
|
+
// Number of valid rows in this tile:
|
|
8453
|
+
// - limited by tile size (Q_TILE_SZ)
|
|
8454
|
+
// - limited by chunk boundary (ir1 - ir)
|
|
8455
|
+
// - limited by head boundary (neq1 - iq1) to avoid crossing into next head
|
|
8456
|
+
const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1)));
|
|
8457
|
+
GGML_ASSERT(tile_rows > 0);
|
|
8458
|
+
|
|
8459
|
+
const uint32_t h = iq2; // head index
|
|
8460
|
+
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
|
8461
|
+
|
|
8462
|
+
float S[Q_TILE_SZ];
|
|
8463
|
+
float M[Q_TILE_SZ];
|
|
8464
|
+
|
|
8465
|
+
for (int i = 0 ; i < Q_TILE_SZ; ++i) {
|
|
8466
|
+
S[i] = 0.;
|
|
8467
|
+
M[i] = -INFINITY;
|
|
8468
|
+
}
|
|
8239
8469
|
|
|
8240
|
-
//
|
|
8241
|
-
//
|
|
8470
|
+
// Per-thread scratch layout:
|
|
8471
|
+
// Q_q: Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar)
|
|
8472
|
+
// KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
|
|
8473
|
+
// mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
|
|
8474
|
+
// VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
|
|
8475
|
+
// V32: KV_TILE_SZ * DV (F32 buffer for V tile)
|
|
8476
|
+
// K_f32: KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path)
|
|
8477
|
+
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32);
|
|
8242
8478
|
|
|
8243
|
-
|
|
8244
|
-
|
|
8479
|
+
void * Q_q = base;
|
|
8480
|
+
float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
|
|
8481
|
+
float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
|
|
8482
|
+
float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
|
|
8483
|
+
float * V32 = VKQ32 + Q_TILE_SZ * DV;
|
|
8484
|
+
float * K_f32 = V32 + KV_TILE_SZ * DV;
|
|
8485
|
+
|
|
8486
|
+
memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
|
|
8487
|
+
memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
|
|
8488
|
+
|
|
8489
|
+
// k indices
|
|
8490
|
+
const int ik3 = iq3 / rk3;
|
|
8491
|
+
const int ik2 = iq2 / rk2;
|
|
8492
|
+
|
|
8493
|
+
// v indices
|
|
8494
|
+
const int iv3 = iq3 / rv3;
|
|
8495
|
+
const int iv2 = iq2 / rv2;
|
|
8496
|
+
|
|
8497
|
+
{
|
|
8498
|
+
float * Q_f32 = (float *)Q_q;
|
|
8499
|
+
for (int tq = 0; tq < tile_rows; tq++) {
|
|
8500
|
+
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
|
|
8501
|
+
memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
|
|
8502
|
+
}
|
|
8503
|
+
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
|
|
8504
|
+
memset(Q_f32 + tq * DK, 0, DK * sizeof(float));
|
|
8505
|
+
}
|
|
8506
|
+
}
|
|
8507
|
+
|
|
8508
|
+
memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
|
|
8509
|
+
memset(V32, 0, KV_TILE_SZ * DV * sizeof(float));
|
|
8510
|
+
|
|
8511
|
+
for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
|
|
8512
|
+
const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);
|
|
8513
|
+
|
|
8514
|
+
// skip the tile entirely if all the masks are -inf
|
|
8515
|
+
if (mask) {
|
|
8516
|
+
bool can_skip = true;
|
|
8517
|
+
for (int tq = 0; tq < tile_rows; tq++) {
|
|
8518
|
+
const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
|
|
8519
|
+
for (int tk = 0; tk < kv_tile; tk++) {
|
|
8520
|
+
mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
|
|
8521
|
+
if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
|
|
8522
|
+
can_skip = false;
|
|
8523
|
+
}
|
|
8524
|
+
}
|
|
8525
|
+
// Pad remaining mask entries with -inf
|
|
8526
|
+
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
|
|
8527
|
+
mask32[tq * KV_TILE_SZ + tk] = -INFINITY;
|
|
8528
|
+
}
|
|
8529
|
+
}
|
|
8530
|
+
|
|
8531
|
+
if (can_skip) {
|
|
8532
|
+
continue;
|
|
8533
|
+
}
|
|
8534
|
+
}
|
|
8535
|
+
|
|
8536
|
+
// Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)
|
|
8537
|
+
// Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
|
|
8538
|
+
for (int tk = 0; tk < kv_tile; tk++) {
|
|
8539
|
+
const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
|
|
8540
|
+
if (kv_type == GGML_TYPE_F16) {
|
|
8541
|
+
const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
|
|
8542
|
+
for (int64_t dk = 0; dk < DK; dk++) {
|
|
8543
|
+
K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);
|
|
8544
|
+
}
|
|
8545
|
+
} else {
|
|
8546
|
+
const float * k_f32_src = (const float *)k_data;
|
|
8547
|
+
for (int64_t dk = 0; dk < DK; dk++) {
|
|
8548
|
+
K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
|
|
8549
|
+
}
|
|
8550
|
+
}
|
|
8551
|
+
}
|
|
8552
|
+
memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
|
|
8553
|
+
simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ);
|
|
8554
|
+
ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale);
|
|
8555
|
+
|
|
8556
|
+
// Set padded KQ entries to -inf so softmax gives them zero weight
|
|
8557
|
+
if (kv_tile < KV_TILE_SZ) {
|
|
8558
|
+
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
|
8559
|
+
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
|
|
8560
|
+
KQ[tq * KV_TILE_SZ + tk] = -INFINITY;
|
|
8561
|
+
}
|
|
8562
|
+
}
|
|
8563
|
+
}
|
|
8564
|
+
|
|
8565
|
+
if (logit_softcap != 0.0f) {
|
|
8566
|
+
ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ);
|
|
8567
|
+
ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap);
|
|
8568
|
+
}
|
|
8569
|
+
|
|
8570
|
+
if (mask) {
|
|
8571
|
+
ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32);
|
|
8572
|
+
}
|
|
8573
|
+
|
|
8574
|
+
bool skip[Q_TILE_SZ] = {};
|
|
8575
|
+
|
|
8576
|
+
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
|
8577
|
+
float * kq_row = KQ + tq * KV_TILE_SZ;
|
|
8578
|
+
|
|
8579
|
+
float tile_max;
|
|
8580
|
+
ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row);
|
|
8581
|
+
|
|
8582
|
+
if (tile_max == -INFINITY) {
|
|
8583
|
+
skip[tq] = true;
|
|
8584
|
+
continue;
|
|
8585
|
+
}
|
|
8586
|
+
|
|
8587
|
+
const float Mold = M[tq];
|
|
8588
|
+
const float Mnew = fmaxf(Mold, tile_max);
|
|
8589
|
+
|
|
8590
|
+
if (Mnew > Mold) {
|
|
8591
|
+
const float ms = expf(Mold - Mnew);
|
|
8592
|
+
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
|
|
8593
|
+
S[tq] *= ms;
|
|
8594
|
+
}
|
|
8595
|
+
M[tq] = Mnew;
|
|
8596
|
+
|
|
8597
|
+
|
|
8598
|
+
S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
|
|
8599
|
+
}
|
|
8600
|
+
|
|
8601
|
+
// V accumulation: VKQ32 += softmax(KQ) * V
|
|
8602
|
+
// Pack V tile to contiguous F32, zero-padded
|
|
8603
|
+
for (int tk = 0; tk < kv_tile; tk++) {
|
|
8604
|
+
const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
|
|
8605
|
+
if (kv_type == GGML_TYPE_F16) {
|
|
8606
|
+
ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
|
|
8607
|
+
} else {
|
|
8608
|
+
memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
|
|
8609
|
+
}
|
|
8610
|
+
}
|
|
8611
|
+
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
|
8612
|
+
if (skip[tq]) {
|
|
8613
|
+
memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float));
|
|
8614
|
+
}
|
|
8615
|
+
}
|
|
8616
|
+
simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV);
|
|
8617
|
+
}
|
|
8618
|
+
|
|
8619
|
+
// sinks (apply only to valid rows in the tile)
|
|
8620
|
+
if (sinks) {
|
|
8621
|
+
const float s = ((float *)((char *) sinks->data))[h];
|
|
8622
|
+
|
|
8623
|
+
for (int tq = 0; tq < tile_rows; tq++) {
|
|
8624
|
+
float ms = 1.0f;
|
|
8625
|
+
float vs = 1.0f;
|
|
8626
|
+
|
|
8627
|
+
if (s > M[tq]) {
|
|
8628
|
+
ms = expf(M[tq] - s);
|
|
8629
|
+
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
|
|
8630
|
+
} else {
|
|
8631
|
+
vs = expf(s - M[tq]);
|
|
8632
|
+
}
|
|
8633
|
+
|
|
8634
|
+
S[tq] = S[tq] * ms + vs;
|
|
8635
|
+
}
|
|
8636
|
+
}
|
|
8637
|
+
|
|
8638
|
+
for (int tq = 0; tq < tile_rows; tq++) {
|
|
8639
|
+
// V /= S
|
|
8640
|
+
const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq];
|
|
8641
|
+
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv);
|
|
8642
|
+
|
|
8643
|
+
// dst indices
|
|
8644
|
+
const int i1 = iq1 + tq;
|
|
8645
|
+
const int i2 = iq2;
|
|
8646
|
+
const int i3 = iq3;
|
|
8647
|
+
|
|
8648
|
+
// permute(0, 2, 1, 3)
|
|
8649
|
+
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1);
|
|
8650
|
+
}
|
|
8651
|
+
|
|
8652
|
+
ir += tile_rows;
|
|
8653
|
+
}
|
|
8654
|
+
}
|
|
8655
|
+
|
|
8656
|
+
// Reduction function: combines partial results across KV chunks
|
|
8657
|
+
// Partials layout in wdata: [n_q_heads][n_chunks][2 + DV]
|
|
8658
|
+
static void ggml_flash_attn_ext_reduce_partials(
|
|
8659
|
+
const ggml_compute_params * params,
|
|
8660
|
+
ggml_tensor * dst,
|
|
8661
|
+
const int64_t n_chunks,
|
|
8662
|
+
const int64_t chunk_size) {
|
|
8663
|
+
|
|
8664
|
+
const ggml_tensor * q = dst->src[0];
|
|
8665
|
+
const ggml_tensor * k = dst->src[1];
|
|
8666
|
+
const ggml_tensor * v = dst->src[2];
|
|
8667
|
+
|
|
8668
|
+
const int64_t DK = k->ne[0];
|
|
8669
|
+
const int64_t DV = v->ne[0];
|
|
8670
|
+
const int64_t nek1 = k->ne[1];
|
|
8671
|
+
const int64_t n_q_heads = q->ne[2];
|
|
8672
|
+
|
|
8673
|
+
const int ith = params->ith;
|
|
8674
|
+
const int nth = params->nth;
|
|
8675
|
+
|
|
8676
|
+
const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32;
|
|
8677
|
+
float * thread_wdata = (float *) params->wdata + ith * wdata_per_thread;
|
|
8678
|
+
|
|
8679
|
+
const int64_t partials_offset = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
|
|
8680
|
+
const int64_t partial_size = 2 + DV;
|
|
8681
|
+
const float * partials_base = (const float *) params->wdata + partials_offset;
|
|
8682
|
+
|
|
8683
|
+
// Output layout
|
|
8684
|
+
const int64_t ne1 = dst->ne[1];
|
|
8685
|
+
const int64_t ne2 = dst->ne[2];
|
|
8686
|
+
const size_t nb1 = dst->nb[1];
|
|
8687
|
+
|
|
8688
|
+
// Each thread reduces a subset of query heads
|
|
8689
|
+
for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) {
|
|
8690
|
+
float M_final = -INFINITY;
|
|
8691
|
+
float S_final = 0.0f;
|
|
8692
|
+
float * VKQ_final = thread_wdata;
|
|
8693
|
+
memset(VKQ_final, 0, DV * sizeof(float));
|
|
8694
|
+
|
|
8695
|
+
// Combine partials from all chunks
|
|
8696
|
+
for (int64_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
|
|
8697
|
+
const int64_t ic_start = chunk_idx * chunk_size;
|
|
8698
|
+
if (ic_start >= nek1) continue;
|
|
8699
|
+
|
|
8700
|
+
const float * partial = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
|
|
8701
|
+
const float M_chunk = partial[0];
|
|
8702
|
+
const float S_chunk = partial[1];
|
|
8703
|
+
const float * VKQ_chunk = partial + 2;
|
|
8704
|
+
|
|
8705
|
+
if (S_chunk == 0.0f) continue;
|
|
8706
|
+
|
|
8707
|
+
const float M_new = fmaxf(M_final, M_chunk);
|
|
8708
|
+
const float scale_old = expf(M_final - M_new);
|
|
8709
|
+
const float scale_new = expf(M_chunk - M_new);
|
|
8710
|
+
|
|
8711
|
+
for (int64_t d = 0; d < DV; ++d) {
|
|
8712
|
+
VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new;
|
|
8713
|
+
}
|
|
8714
|
+
S_final = S_final * scale_old + S_chunk * scale_new;
|
|
8715
|
+
M_final = M_new;
|
|
8716
|
+
}
|
|
8717
|
+
|
|
8718
|
+
// Normalize and write to output
|
|
8719
|
+
if (S_final != 0.0f) {
|
|
8720
|
+
const float S_inv = 1.0f / S_final;
|
|
8721
|
+
ggml_vec_scale_f32(DV, VKQ_final, S_inv);
|
|
8722
|
+
}
|
|
8723
|
+
// iq1=0, iq3=0 for decode
|
|
8724
|
+
memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1);
|
|
8245
8725
|
}
|
|
8246
8726
|
}
|
|
8247
8727
|
|
|
@@ -8266,6 +8746,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
8266
8746
|
const int64_t DV = nev0;
|
|
8267
8747
|
const int64_t N = neq1;
|
|
8268
8748
|
|
|
8749
|
+
|
|
8269
8750
|
GGML_ASSERT(ne0 == DV);
|
|
8270
8751
|
GGML_ASSERT(ne2 == N);
|
|
8271
8752
|
|
|
@@ -8286,47 +8767,92 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
8286
8767
|
GGML_ASSERT(nb1 <= nb2);
|
|
8287
8768
|
GGML_ASSERT(nb2 <= nb3);
|
|
8288
8769
|
|
|
8289
|
-
// parallelize by q rows using ggml_vec_dot_f32
|
|
8290
|
-
|
|
8291
|
-
// total rows in q
|
|
8292
|
-
const int64_t nr = neq1*neq2*neq3;
|
|
8293
|
-
|
|
8294
|
-
// rows per thread
|
|
8295
8770
|
const int ith = params->ith;
|
|
8296
8771
|
const int nth = params->nth;
|
|
8297
8772
|
|
|
8298
|
-
//
|
|
8299
|
-
const bool
|
|
8773
|
+
// When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking)
|
|
8774
|
+
const bool use_ref = params->use_ref;
|
|
8300
8775
|
|
|
8301
|
-
|
|
8302
|
-
|
|
8303
|
-
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
|
|
8304
|
-
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
|
|
8776
|
+
const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
|
|
8777
|
+
const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512;
|
|
8305
8778
|
|
|
8306
|
-
if (
|
|
8307
|
-
|
|
8308
|
-
}
|
|
8779
|
+
if (use_split_kv_path) {
|
|
8780
|
+
const int64_t chunk_size = (nek1 + nth - 1) / nth;
|
|
8309
8781
|
|
|
8310
|
-
|
|
8311
|
-
|
|
8312
|
-
|
|
8313
|
-
}
|
|
8782
|
+
// Partials buffer layout: [q_head][kv_chunk][M, S, VKQ]
|
|
8783
|
+
const int64_t partial_size = 2 + DV;
|
|
8784
|
+
float * partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
|
|
8314
8785
|
|
|
8315
|
-
|
|
8786
|
+
const int64_t ic_start = ith * chunk_size;
|
|
8787
|
+
const int64_t ic_end = std::min(ic_start + chunk_size, nek1);
|
|
8316
8788
|
|
|
8317
|
-
|
|
8318
|
-
|
|
8789
|
+
const int64_t partial_stride = nth * partial_size;
|
|
8790
|
+
float * chunk_partials = partials_base + ith * partial_size;
|
|
8319
8791
|
|
|
8320
|
-
|
|
8321
|
-
|
|
8792
|
+
if (ic_start < nek1) {
|
|
8793
|
+
for (int64_t q_head = 0; q_head < neq2; q_head++) {
|
|
8794
|
+
ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|
8795
|
+
params, dst, q_head, q_head + 1, ic_start, ic_end,
|
|
8796
|
+
chunk_partials, partial_stride);
|
|
8797
|
+
}
|
|
8798
|
+
} else {
|
|
8799
|
+
for (int64_t q_head = 0; q_head < neq2; q_head++) {
|
|
8800
|
+
float * q_partials = chunk_partials + q_head * partial_stride;
|
|
8801
|
+
q_partials[0] = -INFINITY; // M
|
|
8802
|
+
q_partials[1] = 0.0f; // S
|
|
8803
|
+
}
|
|
8804
|
+
}
|
|
8322
8805
|
|
|
8323
|
-
|
|
8324
|
-
|
|
8325
|
-
|
|
8806
|
+
ggml_barrier(params->threadpool);
|
|
8807
|
+
ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size);
|
|
8808
|
+
} else {
|
|
8326
8809
|
|
|
8327
|
-
|
|
8810
|
+
// total rows in q
|
|
8811
|
+
const int64_t nr = neq1*neq2*neq3;
|
|
8328
8812
|
|
|
8329
|
-
|
|
8813
|
+
// disable for NUMA
|
|
8814
|
+
const bool disable_chunking = ggml_is_numa();
|
|
8815
|
+
|
|
8816
|
+
// 4x chunks per thread
|
|
8817
|
+
int nth_scaled = nth * 4;
|
|
8818
|
+
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
|
|
8819
|
+
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
|
|
8820
|
+
|
|
8821
|
+
if (nth == 1 || nchunk < nth || disable_chunking) {
|
|
8822
|
+
nchunk = nth;
|
|
8823
|
+
}
|
|
8824
|
+
|
|
8825
|
+
if (ith == 0) {
|
|
8826
|
+
ggml_threadpool_chunk_set(params->threadpool, nth);
|
|
8827
|
+
}
|
|
8828
|
+
|
|
8829
|
+
ggml_barrier(params->threadpool);
|
|
8830
|
+
|
|
8831
|
+
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
|
8832
|
+
|
|
8833
|
+
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
|
|
8834
|
+
bool use_tiled = !use_ref &&
|
|
8835
|
+
(q->type == GGML_TYPE_F32 &&
|
|
8836
|
+
kv_is_f32_or_f16 &&
|
|
8837
|
+
k->type == v->type &&
|
|
8838
|
+
neq1 >= Q_TILE_SZ);
|
|
8839
|
+
#ifdef GGML_SIMD
|
|
8840
|
+
use_tiled &= (DV % GGML_F32_EPR == 0);
|
|
8841
|
+
#endif
|
|
8842
|
+
int current_chunk = ith;
|
|
8843
|
+
|
|
8844
|
+
while (current_chunk < nchunk) {
|
|
8845
|
+
const int64_t ir0 = dr * current_chunk;
|
|
8846
|
+
const int64_t ir1 = MIN(ir0 + dr, nr);
|
|
8847
|
+
|
|
8848
|
+
if (use_tiled) {
|
|
8849
|
+
ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
|
|
8850
|
+
} else {
|
|
8851
|
+
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0);
|
|
8852
|
+
}
|
|
8853
|
+
|
|
8854
|
+
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
|
8855
|
+
}
|
|
8330
8856
|
}
|
|
8331
8857
|
}
|
|
8332
8858
|
|
|
@@ -9107,7 +9633,7 @@ void ggml_compute_forward_win_unpart(
|
|
|
9107
9633
|
}
|
|
9108
9634
|
}
|
|
9109
9635
|
|
|
9110
|
-
//
|
|
9636
|
+
//ggml_compute_forward_unary
|
|
9111
9637
|
|
|
9112
9638
|
void ggml_compute_forward_unary(
|
|
9113
9639
|
const ggml_compute_params * params,
|
|
@@ -9870,6 +10396,195 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s
|
|
|
9870
10396
|
}
|
|
9871
10397
|
}
|
|
9872
10398
|
|
|
10399
|
+
// ggml_compute_forward_gated_delta_net
|
|
10400
|
+
static void ggml_compute_forward_gated_delta_net_one_chunk(
|
|
10401
|
+
const ggml_compute_params * params,
|
|
10402
|
+
ggml_tensor * dst,
|
|
10403
|
+
int64_t ir0,
|
|
10404
|
+
int64_t ir1) {
|
|
10405
|
+
|
|
10406
|
+
ggml_tensor * src_q = dst->src[0];
|
|
10407
|
+
ggml_tensor * src_k = dst->src[1];
|
|
10408
|
+
ggml_tensor * src_v = dst->src[2];
|
|
10409
|
+
ggml_tensor * src_g = dst->src[3];
|
|
10410
|
+
ggml_tensor * src_beta = dst->src[4];
|
|
10411
|
+
ggml_tensor * src_state = dst->src[5];
|
|
10412
|
+
|
|
10413
|
+
const int64_t S_v = src_v->ne[0];
|
|
10414
|
+
const int64_t H = src_v->ne[1];
|
|
10415
|
+
const int64_t n_tokens = src_v->ne[2];
|
|
10416
|
+
const int64_t n_seqs = src_v->ne[3];
|
|
10417
|
+
|
|
10418
|
+
GGML_ASSERT(ggml_is_contiguous_rows(src_q));
|
|
10419
|
+
GGML_ASSERT(ggml_is_contiguous_rows(src_k));
|
|
10420
|
+
GGML_ASSERT(ggml_is_contiguous_rows(src_v));
|
|
10421
|
+
GGML_ASSERT(ggml_is_contiguous(src_g));
|
|
10422
|
+
GGML_ASSERT(ggml_is_contiguous(src_beta));
|
|
10423
|
+
GGML_ASSERT(ggml_is_contiguous(src_state));
|
|
10424
|
+
|
|
10425
|
+
GGML_ASSERT(src_g->ne[0] == 1 || src_g->ne[0] == S_v);
|
|
10426
|
+
GGML_ASSERT(src_beta->ne[0] == 1);
|
|
10427
|
+
|
|
10428
|
+
GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
|
|
10429
|
+
GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
|
|
10430
|
+
GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
|
|
10431
|
+
GGML_TENSOR_LOCALS(size_t, nbk, src_k, nb);
|
|
10432
|
+
GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
|
|
10433
|
+
GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
|
|
10434
|
+
GGML_TENSOR_LOCALS(int64_t, neg, src_g, ne);
|
|
10435
|
+
GGML_TENSOR_LOCALS(size_t, nbg, src_g, nb);
|
|
10436
|
+
GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
|
|
10437
|
+
|
|
10438
|
+
const bool kda = (neg0 == S_v);
|
|
10439
|
+
|
|
10440
|
+
// scratch layout per thread: [delta(S_v)]
|
|
10441
|
+
const int64_t scratch_per_thread = S_v;
|
|
10442
|
+
const int ith = params->ith;
|
|
10443
|
+
|
|
10444
|
+
float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;
|
|
10445
|
+
|
|
10446
|
+
// output layout: [attn_scores | new_states]
|
|
10447
|
+
// attn_scores: S_v * H * n_tokens * n_seqs floats
|
|
10448
|
+
// new_states: S_v * S_v * H * n_seqs floats
|
|
10449
|
+
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
|
|
10450
|
+
float * attn_out_base = (float *)dst->data;
|
|
10451
|
+
float * state_out_base = (float *)dst->data + attn_score_elems;
|
|
10452
|
+
|
|
10453
|
+
const float * state_in_base = (const float *)src_state->data;
|
|
10454
|
+
|
|
10455
|
+
//const int64_t rq1 = nev1 / neq1;
|
|
10456
|
+
//const int64_t rk1 = nev1 / nek1;
|
|
10457
|
+
const int64_t rq3 = nev3 / neq3;
|
|
10458
|
+
const int64_t rk3 = nev3 / nek3;
|
|
10459
|
+
|
|
10460
|
+
const float scale = 1.0f / sqrtf((float) S_v);
|
|
10461
|
+
|
|
10462
|
+
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
|
10463
|
+
const int64_t iv1 = ir % H; // head_index
|
|
10464
|
+
const int64_t iv3 = ir / H; // sequence
|
|
10465
|
+
|
|
10466
|
+
const int64_t iq1 = iv1 % neq1;
|
|
10467
|
+
const int64_t ik1 = iv1 % nek1;
|
|
10468
|
+
|
|
10469
|
+
const int64_t iq3 = iv3 / rq3;
|
|
10470
|
+
const int64_t ik3 = iv3 / rk3;
|
|
10471
|
+
|
|
10472
|
+
float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v;
|
|
10473
|
+
|
|
10474
|
+
// copy input state into output buffer and operate in-place
|
|
10475
|
+
const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v;
|
|
10476
|
+
memcpy(s_out, s_in, S_v * S_v * sizeof(float));
|
|
10477
|
+
|
|
10478
|
+
// attn output pointer for first token of this (head, seq)
|
|
10479
|
+
float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v;
|
|
10480
|
+
|
|
10481
|
+
for (int64_t t = 0; t < n_tokens; t++) {
|
|
10482
|
+
const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1);
|
|
10483
|
+
const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1);
|
|
10484
|
+
const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);
|
|
10485
|
+
|
|
10486
|
+
const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
|
|
10487
|
+
const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
|
|
10488
|
+
|
|
10489
|
+
// state is stored transposed: s_out[j*S_v + i] = S[i][j]
|
|
10490
|
+
// so row j of s_out = column j of S (contiguous access)
|
|
10491
|
+
|
|
10492
|
+
if (kda) {
|
|
10493
|
+
// precompute exp(g) into delta scratch (reused below)
|
|
10494
|
+
for (int64_t i = 0; i < S_v; ++i) {
|
|
10495
|
+
delta[i] = expf(g_d[i]);
|
|
10496
|
+
}
|
|
10497
|
+
// S[i][:] *= exp(g[i]) => for each row j of M: M[j][i] *= exp(g[i])
|
|
10498
|
+
for (int64_t j = 0; j < S_v; ++j) {
|
|
10499
|
+
ggml_vec_mul_f32(S_v, &s_out[j * S_v], &s_out[j * S_v], delta);
|
|
10500
|
+
}
|
|
10501
|
+
} else {
|
|
10502
|
+
ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0]));
|
|
10503
|
+
}
|
|
10504
|
+
|
|
10505
|
+
// delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k)
|
|
10506
|
+
for (int64_t j = 0; j < S_v; ++j) {
|
|
10507
|
+
float sum = 0.0f;
|
|
10508
|
+
ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, k_d, 0, 1);
|
|
10509
|
+
delta[j] = (v_d[j] - sum) * beta_val;
|
|
10510
|
+
}
|
|
10511
|
+
|
|
10512
|
+
// outer product: S[i][j] += k[i] * delta[j] => M[j][i] += delta[j] * k[i]
|
|
10513
|
+
for (int64_t j = 0; j < S_v; ++j) {
|
|
10514
|
+
ggml_vec_mad_f32(S_v, &s_out[j * S_v], k_d, delta[j]);
|
|
10515
|
+
}
|
|
10516
|
+
|
|
10517
|
+
// attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q)
|
|
10518
|
+
for (int64_t j = 0; j < S_v; ++j) {
|
|
10519
|
+
float sum = 0.0f;
|
|
10520
|
+
ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, q_d, 0, 1);
|
|
10521
|
+
attn_data[j] = sum * scale;
|
|
10522
|
+
}
|
|
10523
|
+
|
|
10524
|
+
attn_data += S_v * H; // advance to next token
|
|
10525
|
+
}
|
|
10526
|
+
}
|
|
10527
|
+
}
|
|
10528
|
+
|
|
10529
|
+
|
|
10530
|
+
static void ggml_compute_forward_gated_delta_net_f32(
|
|
10531
|
+
const ggml_compute_params * params,
|
|
10532
|
+
ggml_tensor * dst) {
|
|
10533
|
+
|
|
10534
|
+
ggml_tensor * V = dst->src[2];
|
|
10535
|
+
int64_t nr = V->ne[1] * V->ne[3];
|
|
10536
|
+
|
|
10537
|
+
// disable for NUMA
|
|
10538
|
+
const bool disable_chunking = ggml_is_numa();
|
|
10539
|
+
|
|
10540
|
+
int nth = params->nth;
|
|
10541
|
+
int ith = params->ith;
|
|
10542
|
+
|
|
10543
|
+
// 4x chunks per thread
|
|
10544
|
+
int nth_scaled = nth * 4;
|
|
10545
|
+
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
|
|
10546
|
+
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
|
|
10547
|
+
|
|
10548
|
+
if (nth == 1 || nchunk < nth || disable_chunking) {
|
|
10549
|
+
nchunk = nth;
|
|
10550
|
+
}
|
|
10551
|
+
|
|
10552
|
+
if (ith == 0) {
|
|
10553
|
+
ggml_threadpool_chunk_set(params->threadpool, nth);
|
|
10554
|
+
}
|
|
10555
|
+
|
|
10556
|
+
ggml_barrier(params->threadpool);
|
|
10557
|
+
|
|
10558
|
+
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
|
10559
|
+
|
|
10560
|
+
int current_chunk = ith;
|
|
10561
|
+
|
|
10562
|
+
while (current_chunk < nchunk) {
|
|
10563
|
+
const int64_t ir0 = dr * current_chunk;
|
|
10564
|
+
const int64_t ir1 = MIN(ir0 + dr, nr);
|
|
10565
|
+
|
|
10566
|
+
ggml_compute_forward_gated_delta_net_one_chunk(params, dst, ir0, ir1);
|
|
10567
|
+
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
|
10568
|
+
}
|
|
10569
|
+
}
|
|
10570
|
+
|
|
10571
|
+
void ggml_compute_forward_gated_delta_net(
|
|
10572
|
+
const ggml_compute_params * params,
|
|
10573
|
+
ggml_tensor * dst) {
|
|
10574
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
10575
|
+
|
|
10576
|
+
switch (src0->type) {
|
|
10577
|
+
case GGML_TYPE_F32:
|
|
10578
|
+
{
|
|
10579
|
+
ggml_compute_forward_gated_delta_net_f32(params, dst);
|
|
10580
|
+
} break;
|
|
10581
|
+
default:
|
|
10582
|
+
{
|
|
10583
|
+
GGML_ABORT("fatal error");
|
|
10584
|
+
}
|
|
10585
|
+
}
|
|
10586
|
+
}
|
|
10587
|
+
|
|
9873
10588
|
// ggml_compute_forward_rwkv_wkv7
|
|
9874
10589
|
|
|
9875
10590
|
static void ggml_compute_forward_rwkv_wkv7_f32(
|
|
@@ -10195,7 +10910,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
|
|
|
10195
10910
|
assert(!isnan(s0[i]));
|
|
10196
10911
|
assert(!isnan(s1[i]));
|
|
10197
10912
|
}
|
|
10198
|
-
#endif
|
|
10913
|
+
#endif // NDEBUG
|
|
10199
10914
|
|
|
10200
10915
|
float max = -INFINITY;
|
|
10201
10916
|
ggml_vec_max_f32(nc, &max, s0);
|
|
@@ -10214,7 +10929,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
|
|
|
10214
10929
|
assert(!isnan(st[i]));
|
|
10215
10930
|
assert(!isinf(st[i]));
|
|
10216
10931
|
}
|
|
10217
|
-
#endif
|
|
10932
|
+
#endif // NDEBUG
|
|
10218
10933
|
}
|
|
10219
10934
|
sums[ith] = sum_thread;
|
|
10220
10935
|
ggml_barrier(params->threadpool);
|
|
@@ -10287,7 +11002,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
|
|
|
10287
11002
|
assert(!isnan(s0[i]));
|
|
10288
11003
|
assert(!isnan(s1[i]));
|
|
10289
11004
|
}
|
|
10290
|
-
#endif
|
|
11005
|
+
#endif // NDEBUG
|
|
10291
11006
|
|
|
10292
11007
|
// soft_max
|
|
10293
11008
|
float max = -INFINITY;
|
|
@@ -10305,7 +11020,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
|
|
|
10305
11020
|
assert(!isnan(ds0[i]));
|
|
10306
11021
|
assert(!isinf(ds0[i]));
|
|
10307
11022
|
}
|
|
10308
|
-
#endif
|
|
11023
|
+
#endif // NDEBUG
|
|
10309
11024
|
}
|
|
10310
11025
|
}
|
|
10311
11026
|
|