whispercpp 1.3.4 → 1.3.6
This diff shows the changes between publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +158 -44
- data/ext/extconf.rb +3 -2
- data/ext/ruby_whisper.c +34 -6
- data/ext/ruby_whisper.h +67 -0
- data/ext/ruby_whisper_context.c +236 -144
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +12 -13
- data/ext/ruby_whisper_params.c +47 -24
- data/ext/ruby_whisper_segment.c +84 -20
- data/ext/ruby_whisper_token.c +371 -0
- data/ext/ruby_whisper_transcribe.cpp +5 -2
- data/ext/ruby_whisper_vad_context.c +122 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +138 -0
- data/ext/ruby_whisper_vad_segments.c +105 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +129 -112
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +28 -15
- data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
- data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
- data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
- data/ext/sources/examples/talk-llama/llama-context.h +70 -23
- data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
- data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
- data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
- data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
- data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
- data/ext/sources/examples/talk-llama/llama-model.h +112 -18
- data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
- data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
- data/ext/sources/examples/talk-llama/llama.cpp +802 -21
- data/ext/sources/examples/talk-llama/llama.h +210 -39
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +704 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +90 -56
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +5 -2
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +6 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +14 -12
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +246 -21
- data/ext/sources/ggml/src/CMakeLists.txt +85 -11
- data/ext/sources/ggml/src/ggml-alloc.c +128 -50
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
- data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
- data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
- data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
- data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
- data/ext/sources/ggml/src/ggml-impl.h +129 -6
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
- data/ext/sources/ggml/src/ggml.c +590 -64
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +106 -62
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +162 -4
- data/test/test_context_params.rb +82 -0
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +81 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +100 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +27 -0
- data/whispercpp.gemspec +1 -1
- metadata +502 -37
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -3,12 +3,14 @@
|
|
|
3
3
|
#include "ggml-cpu.h"
|
|
4
4
|
#include "ggml-impl.h"
|
|
5
5
|
#include "binary-ops.h"
|
|
6
|
+
#include "simd-gemm.h"
|
|
6
7
|
#include "ggml.h"
|
|
7
8
|
#include "unary-ops.h"
|
|
8
9
|
#include "vec.h"
|
|
9
10
|
|
|
10
|
-
#include <float.h>
|
|
11
11
|
#include <algorithm>
|
|
12
|
+
#include <cfloat>
|
|
13
|
+
#include <cmath>
|
|
12
14
|
|
|
13
15
|
// ggml_compute_forward_dup
|
|
14
16
|
|
|
@@ -373,7 +375,7 @@ static void ggml_compute_forward_dup_bytes(
|
|
|
373
375
|
const size_t rs = ne00 * type_size;
|
|
374
376
|
|
|
375
377
|
if (nb00 == type_size) {
|
|
376
|
-
// src0 is
|
|
378
|
+
// src0 is contiguous on first dimension, copy by rows
|
|
377
379
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
378
380
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
379
381
|
id += rs * ir0;
|
|
@@ -668,6 +670,7 @@ void ggml_compute_forward_add(
|
|
|
668
670
|
case GGML_TYPE_Q5_1:
|
|
669
671
|
case GGML_TYPE_Q8_0:
|
|
670
672
|
case GGML_TYPE_MXFP4:
|
|
673
|
+
case GGML_TYPE_NVFP4:
|
|
671
674
|
case GGML_TYPE_Q2_K:
|
|
672
675
|
case GGML_TYPE_Q3_K:
|
|
673
676
|
case GGML_TYPE_Q4_K:
|
|
@@ -1117,6 +1120,7 @@ void ggml_compute_forward_add1(
|
|
|
1117
1120
|
case GGML_TYPE_Q8_0:
|
|
1118
1121
|
case GGML_TYPE_Q8_1:
|
|
1119
1122
|
case GGML_TYPE_MXFP4:
|
|
1123
|
+
case GGML_TYPE_NVFP4:
|
|
1120
1124
|
case GGML_TYPE_Q2_K:
|
|
1121
1125
|
case GGML_TYPE_Q3_K:
|
|
1122
1126
|
case GGML_TYPE_Q4_K:
|
|
@@ -1245,6 +1249,7 @@ void ggml_compute_forward_acc(
|
|
|
1245
1249
|
case GGML_TYPE_Q8_0:
|
|
1246
1250
|
case GGML_TYPE_Q8_1:
|
|
1247
1251
|
case GGML_TYPE_MXFP4:
|
|
1252
|
+
case GGML_TYPE_NVFP4:
|
|
1248
1253
|
case GGML_TYPE_Q2_K:
|
|
1249
1254
|
case GGML_TYPE_Q3_K:
|
|
1250
1255
|
case GGML_TYPE_Q4_K:
|
|
@@ -1394,6 +1399,56 @@ void ggml_compute_forward_sum(
|
|
|
1394
1399
|
}
|
|
1395
1400
|
}
|
|
1396
1401
|
|
|
1402
|
+
// ggml_compute_forward_cumsum
|
|
1403
|
+
|
|
1404
|
+
static void ggml_compute_forward_cumsum_f32(
|
|
1405
|
+
const ggml_compute_params * params,
|
|
1406
|
+
ggml_tensor * dst) {
|
|
1407
|
+
|
|
1408
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
1409
|
+
|
|
1410
|
+
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
|
1411
|
+
GGML_ASSERT(dst->nb[0] == sizeof(float));
|
|
1412
|
+
|
|
1413
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
|
1414
|
+
|
|
1415
|
+
GGML_ASSERT(ne0 == ne00);
|
|
1416
|
+
GGML_ASSERT(ne1 == ne01);
|
|
1417
|
+
GGML_ASSERT(ne2 == ne02);
|
|
1418
|
+
GGML_ASSERT(ne3 == ne03);
|
|
1419
|
+
|
|
1420
|
+
const auto [ir0, ir1] = get_thread_range(params, src0);
|
|
1421
|
+
|
|
1422
|
+
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
|
1423
|
+
const int64_t i03 = ir/(ne02*ne01);
|
|
1424
|
+
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
|
1425
|
+
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
|
1426
|
+
|
|
1427
|
+
float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
1428
|
+
float * dst_row = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
|
1429
|
+
|
|
1430
|
+
ggml_vec_cumsum_f32(ne00, dst_row, src_row);
|
|
1431
|
+
}
|
|
1432
|
+
}
|
|
1433
|
+
|
|
1434
|
+
void ggml_compute_forward_cumsum(
|
|
1435
|
+
const ggml_compute_params * params,
|
|
1436
|
+
ggml_tensor * dst) {
|
|
1437
|
+
|
|
1438
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
1439
|
+
|
|
1440
|
+
switch (src0->type) {
|
|
1441
|
+
case GGML_TYPE_F32:
|
|
1442
|
+
{
|
|
1443
|
+
ggml_compute_forward_cumsum_f32(params, dst);
|
|
1444
|
+
} break;
|
|
1445
|
+
default:
|
|
1446
|
+
{
|
|
1447
|
+
GGML_ABORT("fatal error");
|
|
1448
|
+
}
|
|
1449
|
+
}
|
|
1450
|
+
}
|
|
1451
|
+
|
|
1397
1452
|
// ggml_compute_forward_sum_rows
|
|
1398
1453
|
|
|
1399
1454
|
static void ggml_compute_forward_sum_rows_f32(
|
|
@@ -1743,7 +1798,7 @@ void ggml_compute_forward_repeat(
|
|
|
1743
1798
|
{
|
|
1744
1799
|
ggml_compute_forward_repeat_f32(params, dst);
|
|
1745
1800
|
} break;
|
|
1746
|
-
// TODO: templateify the
|
|
1801
|
+
// TODO: templateify the implementation and support for I64
|
|
1747
1802
|
// ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
|
|
1748
1803
|
//case GGML_TYPE_I64:
|
|
1749
1804
|
// {
|
|
@@ -2045,10 +2100,14 @@ static void ggml_compute_forward_gelu_f32(
|
|
|
2045
2100
|
|
|
2046
2101
|
const ggml_tensor * src0 = dst->src[0];
|
|
2047
2102
|
|
|
2048
|
-
assert(
|
|
2049
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2103
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2050
2104
|
assert(ggml_are_same_shape(src0, dst));
|
|
2051
2105
|
|
|
2106
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2107
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2108
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2109
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2110
|
+
|
|
2052
2111
|
const int ith = params->ith;
|
|
2053
2112
|
const int nth = params->nth;
|
|
2054
2113
|
|
|
@@ -2062,19 +2121,23 @@ static void ggml_compute_forward_gelu_f32(
|
|
|
2062
2121
|
const int ir0 = dr*ith;
|
|
2063
2122
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2064
2123
|
|
|
2065
|
-
for (int
|
|
2124
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2125
|
+
const int i3 = ir/(ne02*ne01);
|
|
2126
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2127
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2128
|
+
|
|
2066
2129
|
ggml_vec_gelu_f32(nc,
|
|
2067
|
-
(float *) ((char *) dst->data + i1*
|
|
2068
|
-
(float *) ((char *) src0->data + i1*
|
|
2130
|
+
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2131
|
+
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2069
2132
|
|
|
2070
2133
|
#ifndef NDEBUG
|
|
2071
2134
|
for (int k = 0; k < nc; k++) {
|
|
2072
|
-
const float x = ((float *) ((char *) dst->data + i1*(
|
|
2135
|
+
const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
|
|
2073
2136
|
GGML_UNUSED(x);
|
|
2074
2137
|
assert(!isnan(x));
|
|
2075
2138
|
assert(!isinf(x));
|
|
2076
2139
|
}
|
|
2077
|
-
#endif
|
|
2140
|
+
#endif // NDEBUG
|
|
2078
2141
|
}
|
|
2079
2142
|
}
|
|
2080
2143
|
|
|
@@ -2084,10 +2147,14 @@ static void ggml_compute_forward_gelu_f16(
|
|
|
2084
2147
|
|
|
2085
2148
|
const ggml_tensor * src0 = dst->src[0];
|
|
2086
2149
|
|
|
2087
|
-
assert(
|
|
2088
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2150
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2089
2151
|
assert(ggml_are_same_shape(src0, dst));
|
|
2090
2152
|
|
|
2153
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2154
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2155
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2156
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2157
|
+
|
|
2091
2158
|
const int ith = params->ith;
|
|
2092
2159
|
const int nth = params->nth;
|
|
2093
2160
|
|
|
@@ -2101,20 +2168,24 @@ static void ggml_compute_forward_gelu_f16(
|
|
|
2101
2168
|
const int ir0 = dr*ith;
|
|
2102
2169
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2103
2170
|
|
|
2104
|
-
for (int
|
|
2171
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2172
|
+
const int i3 = ir/(ne02*ne01);
|
|
2173
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2174
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2175
|
+
|
|
2105
2176
|
ggml_vec_gelu_f16(nc,
|
|
2106
|
-
(ggml_fp16_t *) ((char *) dst->data + i1*
|
|
2107
|
-
(ggml_fp16_t *) ((char *) src0->data + i1*
|
|
2177
|
+
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2178
|
+
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2108
2179
|
|
|
2109
2180
|
#ifndef NDEBUG
|
|
2110
2181
|
for (int k = 0; k < nc; k++) {
|
|
2111
|
-
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
2182
|
+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
|
|
2112
2183
|
const float v = GGML_CPU_FP16_TO_FP32(x);
|
|
2113
2184
|
GGML_UNUSED(v);
|
|
2114
2185
|
assert(!isnan(v));
|
|
2115
2186
|
assert(!isinf(v));
|
|
2116
2187
|
}
|
|
2117
|
-
#endif
|
|
2188
|
+
#endif // NDEBUG
|
|
2118
2189
|
}
|
|
2119
2190
|
}
|
|
2120
2191
|
|
|
@@ -2140,6 +2211,83 @@ static void ggml_compute_forward_gelu(
|
|
|
2140
2211
|
}
|
|
2141
2212
|
}
|
|
2142
2213
|
|
|
2214
|
+
// ggml_compute_fill
|
|
2215
|
+
|
|
2216
|
+
static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
2217
|
+
const float c = ggml_get_op_params_f32(dst, 0);
|
|
2218
|
+
|
|
2219
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
|
|
2220
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
|
|
2221
|
+
|
|
2222
|
+
const auto [ir0, ir1] = get_thread_range(params, dst);
|
|
2223
|
+
|
|
2224
|
+
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
|
2225
|
+
const int64_t i03 = ir/(ne2*ne1);
|
|
2226
|
+
const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
|
|
2227
|
+
const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
|
|
2228
|
+
|
|
2229
|
+
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
|
|
2230
|
+
|
|
2231
|
+
ggml_vec_set_f32(ne0, dst_ptr, c);
|
|
2232
|
+
}
|
|
2233
|
+
}
|
|
2234
|
+
|
|
2235
|
+
void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
2236
|
+
ggml_compute_forward_fill_f32(params, dst);
|
|
2237
|
+
}
|
|
2238
|
+
|
|
2239
|
+
// ggml_compute_tri
|
|
2240
|
+
|
|
2241
|
+
static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
2242
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
2243
|
+
|
|
2244
|
+
const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
|
|
2245
|
+
|
|
2246
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
2247
|
+
|
|
2248
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
|
2249
|
+
|
|
2250
|
+
const auto [ir0, ir1] = get_thread_range(params, src0);
|
|
2251
|
+
|
|
2252
|
+
bool (*bipred)(int, int);
|
|
2253
|
+
|
|
2254
|
+
switch (ttype) {
|
|
2255
|
+
case GGML_TRI_TYPE_LOWER: bipred = [](int i, int r) { return i < r; }; break;
|
|
2256
|
+
case GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
|
|
2257
|
+
case GGML_TRI_TYPE_UPPER: bipred = [](int i, int r) { return i > r; }; break;
|
|
2258
|
+
case GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
|
|
2259
|
+
default: GGML_ABORT("invalid tri type");
|
|
2260
|
+
}
|
|
2261
|
+
|
|
2262
|
+
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
|
2263
|
+
const int64_t i03 = ir/(ne02*ne01);
|
|
2264
|
+
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
|
2265
|
+
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
|
2266
|
+
|
|
2267
|
+
const float * src_ptr = (const float *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
|
2268
|
+
float * dst_ptr = ( float *) (( char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
|
|
2269
|
+
|
|
2270
|
+
for (int i0 = 0; i0 < ne0; ++i0) {
|
|
2271
|
+
dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
|
|
2272
|
+
}
|
|
2273
|
+
}
|
|
2274
|
+
}
|
|
2275
|
+
|
|
2276
|
+
void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
2277
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
2278
|
+
|
|
2279
|
+
switch (src0->type) {
|
|
2280
|
+
case GGML_TYPE_F32:
|
|
2281
|
+
{
|
|
2282
|
+
ggml_compute_forward_tri_f32(params, dst);
|
|
2283
|
+
} break;
|
|
2284
|
+
default:
|
|
2285
|
+
{
|
|
2286
|
+
GGML_ABORT("fatal error");
|
|
2287
|
+
}
|
|
2288
|
+
}
|
|
2289
|
+
}
|
|
2290
|
+
|
|
2143
2291
|
// ggml_compute_forward_gelu_erf
|
|
2144
2292
|
|
|
2145
2293
|
static void ggml_compute_forward_gelu_erf_f32(
|
|
@@ -2148,10 +2296,14 @@ static void ggml_compute_forward_gelu_erf_f32(
|
|
|
2148
2296
|
|
|
2149
2297
|
const ggml_tensor * src0 = dst->src[0];
|
|
2150
2298
|
|
|
2151
|
-
assert(
|
|
2152
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2299
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2153
2300
|
assert(ggml_are_same_shape(src0, dst));
|
|
2154
2301
|
|
|
2302
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2303
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2304
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2305
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2306
|
+
|
|
2155
2307
|
const int ith = params->ith;
|
|
2156
2308
|
const int nth = params->nth;
|
|
2157
2309
|
|
|
@@ -2165,19 +2317,23 @@ static void ggml_compute_forward_gelu_erf_f32(
|
|
|
2165
2317
|
const int ir0 = dr*ith;
|
|
2166
2318
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2167
2319
|
|
|
2168
|
-
for (int
|
|
2320
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2321
|
+
const int i3 = ir/(ne02*ne01);
|
|
2322
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2323
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2324
|
+
|
|
2169
2325
|
ggml_vec_gelu_erf_f32(nc,
|
|
2170
|
-
(float *) ((char *) dst->data + i1*
|
|
2171
|
-
(float *) ((char *) src0->data + i1*
|
|
2326
|
+
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2327
|
+
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2172
2328
|
|
|
2173
2329
|
#ifndef NDEBUG
|
|
2174
2330
|
for (int k = 0; k < nc; k++) {
|
|
2175
|
-
const float x = ((float *) ((char *) dst->data + i1*(
|
|
2331
|
+
const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
|
|
2176
2332
|
GGML_UNUSED(x);
|
|
2177
2333
|
assert(!isnan(x));
|
|
2178
2334
|
assert(!isinf(x));
|
|
2179
2335
|
}
|
|
2180
|
-
#endif
|
|
2336
|
+
#endif // NDEBUG
|
|
2181
2337
|
}
|
|
2182
2338
|
}
|
|
2183
2339
|
|
|
@@ -2187,10 +2343,14 @@ static void ggml_compute_forward_gelu_erf_f16(
|
|
|
2187
2343
|
|
|
2188
2344
|
const ggml_tensor * src0 = dst->src[0];
|
|
2189
2345
|
|
|
2190
|
-
assert(
|
|
2191
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2346
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2192
2347
|
assert(ggml_are_same_shape(src0, dst));
|
|
2193
2348
|
|
|
2349
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2350
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2351
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2352
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2353
|
+
|
|
2194
2354
|
const int ith = params->ith;
|
|
2195
2355
|
const int nth = params->nth;
|
|
2196
2356
|
|
|
@@ -2204,20 +2364,24 @@ static void ggml_compute_forward_gelu_erf_f16(
|
|
|
2204
2364
|
const int ir0 = dr*ith;
|
|
2205
2365
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2206
2366
|
|
|
2207
|
-
for (int
|
|
2367
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2368
|
+
const int i3 = ir/(ne02*ne01);
|
|
2369
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2370
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2371
|
+
|
|
2208
2372
|
ggml_vec_gelu_erf_f16(nc,
|
|
2209
|
-
(ggml_fp16_t *) ((char *) dst->data + i1*
|
|
2210
|
-
(ggml_fp16_t *) ((char *) src0->data + i1*
|
|
2373
|
+
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2374
|
+
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2211
2375
|
|
|
2212
2376
|
#ifndef NDEBUG
|
|
2213
2377
|
for (int k = 0; k < nc; k++) {
|
|
2214
|
-
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
2378
|
+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
|
|
2215
2379
|
const float v = GGML_CPU_FP16_TO_FP32(x);
|
|
2216
2380
|
GGML_UNUSED(v);
|
|
2217
2381
|
assert(!isnan(v));
|
|
2218
2382
|
assert(!isinf(v));
|
|
2219
2383
|
}
|
|
2220
|
-
#endif
|
|
2384
|
+
#endif // NDEBUG
|
|
2221
2385
|
}
|
|
2222
2386
|
}
|
|
2223
2387
|
|
|
@@ -2251,10 +2415,14 @@ static void ggml_compute_forward_gelu_quick_f32(
|
|
|
2251
2415
|
|
|
2252
2416
|
const ggml_tensor * src0 = dst->src[0];
|
|
2253
2417
|
|
|
2254
|
-
assert(
|
|
2255
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2418
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2256
2419
|
assert(ggml_are_same_shape(src0, dst));
|
|
2257
2420
|
|
|
2421
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2422
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2423
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2424
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2425
|
+
|
|
2258
2426
|
const int ith = params->ith;
|
|
2259
2427
|
const int nth = params->nth;
|
|
2260
2428
|
|
|
@@ -2268,19 +2436,23 @@ static void ggml_compute_forward_gelu_quick_f32(
|
|
|
2268
2436
|
const int ir0 = dr*ith;
|
|
2269
2437
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2270
2438
|
|
|
2271
|
-
for (int
|
|
2439
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2440
|
+
const int i3 = ir/(ne02*ne01);
|
|
2441
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2442
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2443
|
+
|
|
2272
2444
|
ggml_vec_gelu_quick_f32(nc,
|
|
2273
|
-
(float *) ((char *) dst->data + i1*
|
|
2274
|
-
(float *) ((char *) src0->data + i1*
|
|
2445
|
+
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2446
|
+
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2275
2447
|
|
|
2276
2448
|
#ifndef NDEBUG
|
|
2277
2449
|
for (int k = 0; k < nc; k++) {
|
|
2278
|
-
const float x = ((float *) ((char *) dst->data + i1*(
|
|
2450
|
+
const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
|
|
2279
2451
|
GGML_UNUSED(x);
|
|
2280
2452
|
assert(!isnan(x));
|
|
2281
2453
|
assert(!isinf(x));
|
|
2282
2454
|
}
|
|
2283
|
-
#endif
|
|
2455
|
+
#endif // NDEBUG
|
|
2284
2456
|
}
|
|
2285
2457
|
}
|
|
2286
2458
|
|
|
@@ -2290,10 +2462,14 @@ static void ggml_compute_forward_gelu_quick_f16(
|
|
|
2290
2462
|
|
|
2291
2463
|
const ggml_tensor * src0 = dst->src[0];
|
|
2292
2464
|
|
|
2293
|
-
assert(
|
|
2294
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2465
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2295
2466
|
assert(ggml_are_same_shape(src0, dst));
|
|
2296
2467
|
|
|
2468
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2469
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2470
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2471
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2472
|
+
|
|
2297
2473
|
const int ith = params->ith;
|
|
2298
2474
|
const int nth = params->nth;
|
|
2299
2475
|
|
|
@@ -2307,20 +2483,24 @@ static void ggml_compute_forward_gelu_quick_f16(
|
|
|
2307
2483
|
const int ir0 = dr*ith;
|
|
2308
2484
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2309
2485
|
|
|
2310
|
-
for (int
|
|
2486
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2487
|
+
const int i3 = ir/(ne02*ne01);
|
|
2488
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2489
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2490
|
+
|
|
2311
2491
|
ggml_vec_gelu_quick_f16(nc,
|
|
2312
|
-
(ggml_fp16_t *) ((char *) dst->data + i1*
|
|
2313
|
-
(ggml_fp16_t *) ((char *) src0->data + i1*
|
|
2492
|
+
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2493
|
+
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2314
2494
|
|
|
2315
2495
|
#ifndef NDEBUG
|
|
2316
2496
|
for (int k = 0; k < nc; k++) {
|
|
2317
|
-
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
2497
|
+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
|
|
2318
2498
|
const float v = GGML_CPU_FP16_TO_FP32(x);
|
|
2319
2499
|
GGML_UNUSED(v);
|
|
2320
2500
|
assert(!isnan(v));
|
|
2321
2501
|
assert(!isinf(v));
|
|
2322
2502
|
}
|
|
2323
|
-
#endif
|
|
2503
|
+
#endif // NDEBUG
|
|
2324
2504
|
}
|
|
2325
2505
|
}
|
|
2326
2506
|
|
|
@@ -2354,10 +2534,14 @@ static void ggml_compute_forward_silu_f32(
|
|
|
2354
2534
|
|
|
2355
2535
|
const ggml_tensor * src0 = dst->src[0];
|
|
2356
2536
|
|
|
2357
|
-
assert(
|
|
2358
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2537
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2359
2538
|
assert(ggml_are_same_shape(src0, dst));
|
|
2360
2539
|
|
|
2540
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2541
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2542
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2543
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2544
|
+
|
|
2361
2545
|
const int ith = params->ith;
|
|
2362
2546
|
const int nth = params->nth;
|
|
2363
2547
|
|
|
@@ -2371,19 +2555,23 @@ static void ggml_compute_forward_silu_f32(
|
|
|
2371
2555
|
const int ir0 = dr*ith;
|
|
2372
2556
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2373
2557
|
|
|
2374
|
-
for (int
|
|
2558
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2559
|
+
const int i3 = ir/(ne02*ne01);
|
|
2560
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2561
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2562
|
+
|
|
2375
2563
|
ggml_vec_silu_f32(nc,
|
|
2376
|
-
(float *) ((char *) dst->data + i1*
|
|
2377
|
-
(float *) ((char *) src0->data + i1*
|
|
2564
|
+
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2565
|
+
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2378
2566
|
|
|
2379
2567
|
#ifndef NDEBUG
|
|
2380
2568
|
for (int k = 0; k < nc; k++) {
|
|
2381
|
-
const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
|
|
2569
|
+
const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
|
|
2382
2570
|
GGML_UNUSED(x);
|
|
2383
2571
|
assert(!isnan(x));
|
|
2384
2572
|
assert(!isinf(x));
|
|
2385
2573
|
}
|
|
2386
|
-
#endif
|
|
2574
|
+
#endif // NDEBUG
|
|
2387
2575
|
}
|
|
2388
2576
|
}
|
|
2389
2577
|
|
|
@@ -2393,10 +2581,14 @@ static void ggml_compute_forward_silu_f16(
|
|
|
2393
2581
|
|
|
2394
2582
|
const ggml_tensor * src0 = dst->src[0];
|
|
2395
2583
|
|
|
2396
|
-
assert(
|
|
2397
|
-
assert(ggml_is_contiguous_1(dst));
|
|
2584
|
+
assert(ggml_is_contiguous_rows(src0));
|
|
2398
2585
|
assert(ggml_are_same_shape(src0, dst));
|
|
2399
2586
|
|
|
2587
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
|
2588
|
+
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
|
2589
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
2590
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
2591
|
+
|
|
2400
2592
|
const int ith = params->ith;
|
|
2401
2593
|
const int nth = params->nth;
|
|
2402
2594
|
|
|
@@ -2410,20 +2602,24 @@ static void ggml_compute_forward_silu_f16(
|
|
|
2410
2602
|
const int ir0 = dr*ith;
|
|
2411
2603
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
2412
2604
|
|
|
2413
|
-
for (int
|
|
2605
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
2606
|
+
const int i3 = ir/(ne02*ne01);
|
|
2607
|
+
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
|
2608
|
+
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
|
2609
|
+
|
|
2414
2610
|
ggml_vec_silu_f16(nc,
|
|
2415
|
-
(ggml_fp16_t *) ((char *) dst->data + i1*
|
|
2416
|
-
(ggml_fp16_t *) ((char *) src0->data + i1*
|
|
2611
|
+
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
|
2612
|
+
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
|
2417
2613
|
|
|
2418
2614
|
#ifndef NDEBUG
|
|
2419
2615
|
for (int k = 0; k < nc; k++) {
|
|
2420
|
-
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
|
|
2616
|
+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
|
|
2421
2617
|
const float v = GGML_CPU_FP16_TO_FP32(x);
|
|
2422
2618
|
GGML_UNUSED(v);
|
|
2423
2619
|
assert(!isnan(v));
|
|
2424
2620
|
assert(!isinf(v));
|
|
2425
2621
|
}
|
|
2426
|
-
#endif
|
|
2622
|
+
#endif // NDEBUG
|
|
2427
2623
|
}
|
|
2428
2624
|
}
|
|
2429
2625
|
|
|
@@ -2573,7 +2769,7 @@ static void ggml_compute_forward_silu_back_f32(
|
|
|
2573
2769
|
assert(!isnan(x));
|
|
2574
2770
|
assert(!isinf(x));
|
|
2575
2771
|
}
|
|
2576
|
-
#endif
|
|
2772
|
+
#endif // NDEBUG
|
|
2577
2773
|
}
|
|
2578
2774
|
}
|
|
2579
2775
|
|
|
@@ -2609,7 +2805,7 @@ static void ggml_compute_forward_silu_back_f16(
|
|
|
2609
2805
|
(ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
|
|
2610
2806
|
(ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
|
|
2611
2807
|
|
|
2612
|
-
|
|
2808
|
+
#ifndef NDEBUG
|
|
2613
2809
|
for (int k = 0; k < nc; k++) {
|
|
2614
2810
|
const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
2615
2811
|
const float v = GGML_CPU_FP16_TO_FP32(x);
|
|
@@ -2617,7 +2813,7 @@ static void ggml_compute_forward_silu_back_f16(
|
|
|
2617
2813
|
assert(!isnan(v));
|
|
2618
2814
|
assert(!isinf(v));
|
|
2619
2815
|
}
|
|
2620
|
-
|
|
2816
|
+
#endif // NDEBUG
|
|
2621
2817
|
}
|
|
2622
2818
|
}
|
|
2623
2819
|
|
|
@@ -2700,7 +2896,7 @@ static void ggml_compute_forward_reglu_f32(
|
|
|
2700
2896
|
assert(!isnan(x));
|
|
2701
2897
|
assert(!isinf(x));
|
|
2702
2898
|
}
|
|
2703
|
-
#endif
|
|
2899
|
+
#endif // NDEBUG
|
|
2704
2900
|
}
|
|
2705
2901
|
}
|
|
2706
2902
|
|
|
@@ -2760,7 +2956,7 @@ static void ggml_compute_forward_reglu_f16(
|
|
|
2760
2956
|
assert(!isnan(v));
|
|
2761
2957
|
assert(!isinf(v));
|
|
2762
2958
|
}
|
|
2763
|
-
#endif
|
|
2959
|
+
#endif // NDEBUG
|
|
2764
2960
|
}
|
|
2765
2961
|
}
|
|
2766
2962
|
|
|
@@ -2843,7 +3039,7 @@ static void ggml_compute_forward_geglu_f32(
|
|
|
2843
3039
|
assert(!isnan(x));
|
|
2844
3040
|
assert(!isinf(x));
|
|
2845
3041
|
}
|
|
2846
|
-
#endif
|
|
3042
|
+
#endif // NDEBUG
|
|
2847
3043
|
}
|
|
2848
3044
|
}
|
|
2849
3045
|
|
|
@@ -2903,7 +3099,7 @@ static void ggml_compute_forward_geglu_f16(
|
|
|
2903
3099
|
assert(!isnan(v));
|
|
2904
3100
|
assert(!isinf(v));
|
|
2905
3101
|
}
|
|
2906
|
-
#endif
|
|
3102
|
+
#endif // NDEBUG
|
|
2907
3103
|
}
|
|
2908
3104
|
}
|
|
2909
3105
|
|
|
@@ -2986,7 +3182,7 @@ static void ggml_compute_forward_swiglu_f32(
|
|
|
2986
3182
|
assert(!isnan(x));
|
|
2987
3183
|
assert(!isinf(x));
|
|
2988
3184
|
}
|
|
2989
|
-
#endif
|
|
3185
|
+
#endif // NDEBUG
|
|
2990
3186
|
}
|
|
2991
3187
|
}
|
|
2992
3188
|
|
|
@@ -3046,7 +3242,7 @@ static void ggml_compute_forward_swiglu_f16(
|
|
|
3046
3242
|
assert(!isnan(v));
|
|
3047
3243
|
assert(!isinf(v));
|
|
3048
3244
|
}
|
|
3049
|
-
#endif
|
|
3245
|
+
#endif // NDEBUG
|
|
3050
3246
|
}
|
|
3051
3247
|
}
|
|
3052
3248
|
|
|
@@ -3137,7 +3333,7 @@ static void ggml_compute_forward_swiglu_oai_f32(
|
|
|
3137
3333
|
assert(!isnan(x));
|
|
3138
3334
|
assert(!isinf(x));
|
|
3139
3335
|
}
|
|
3140
|
-
#endif
|
|
3336
|
+
#endif // NDEBUG
|
|
3141
3337
|
}
|
|
3142
3338
|
}
|
|
3143
3339
|
|
|
@@ -3216,7 +3412,7 @@ static void ggml_compute_forward_geglu_erf_f32(
|
|
|
3216
3412
|
assert(!isnan(x));
|
|
3217
3413
|
assert(!isinf(x));
|
|
3218
3414
|
}
|
|
3219
|
-
#endif
|
|
3415
|
+
#endif // NDEBUG
|
|
3220
3416
|
}
|
|
3221
3417
|
}
|
|
3222
3418
|
|
|
@@ -3276,7 +3472,7 @@ static void ggml_compute_forward_geglu_erf_f16(
|
|
|
3276
3472
|
assert(!isnan(v));
|
|
3277
3473
|
assert(!isinf(v));
|
|
3278
3474
|
}
|
|
3279
|
-
#endif
|
|
3475
|
+
#endif // NDEBUG
|
|
3280
3476
|
}
|
|
3281
3477
|
}
|
|
3282
3478
|
|
|
@@ -3359,7 +3555,7 @@ static void ggml_compute_forward_geglu_quick_f32(
|
|
|
3359
3555
|
assert(!isnan(x));
|
|
3360
3556
|
assert(!isinf(x));
|
|
3361
3557
|
}
|
|
3362
|
-
#endif
|
|
3558
|
+
#endif // NDEBUG
|
|
3363
3559
|
}
|
|
3364
3560
|
}
|
|
3365
3561
|
|
|
@@ -3419,7 +3615,7 @@ static void ggml_compute_forward_geglu_quick_f16(
|
|
|
3419
3615
|
assert(!isnan(v));
|
|
3420
3616
|
assert(!isinf(v));
|
|
3421
3617
|
}
|
|
3422
|
-
#endif
|
|
3618
|
+
#endif // NDEBUG
|
|
3423
3619
|
}
|
|
3424
3620
|
}
|
|
3425
3621
|
|
|
@@ -3467,31 +3663,27 @@ static void ggml_compute_forward_norm_f32(
|
|
|
3467
3663
|
|
|
3468
3664
|
GGML_ASSERT(eps >= 0.0f);
|
|
3469
3665
|
|
|
3470
|
-
// TODO: optimize
|
|
3471
3666
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
3472
3667
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
3473
3668
|
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
|
3474
3669
|
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
3475
3670
|
|
|
3476
|
-
|
|
3477
|
-
|
|
3478
|
-
sum += (ggml_float)x[i00];
|
|
3479
|
-
}
|
|
3480
|
-
|
|
3671
|
+
float sum = 0.0;
|
|
3672
|
+
ggml_vec_sum_f32(ne00, &sum, x);
|
|
3481
3673
|
float mean = sum/ne00;
|
|
3482
3674
|
|
|
3483
3675
|
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
|
3676
|
+
float variance = 0;
|
|
3484
3677
|
|
|
3485
|
-
|
|
3486
|
-
|
|
3487
|
-
|
|
3488
|
-
|
|
3489
|
-
|
|
3490
|
-
|
|
3678
|
+
#ifdef GGML_USE_ACCELERATE
|
|
3679
|
+
mean = -mean;
|
|
3680
|
+
vDSP_vsadd(x, 1, &mean, y, 1, ne00);
|
|
3681
|
+
vDSP_measqv(y, 1, &variance, ne00);
|
|
3682
|
+
#else
|
|
3683
|
+
variance = ggml_vec_cvar_f32(ne00, y, x, mean);
|
|
3684
|
+
#endif //GGML_USE_ACCELERATE
|
|
3491
3685
|
|
|
3492
|
-
float variance = sum2/ne00;
|
|
3493
3686
|
const float scale = 1.0f/sqrtf(variance + eps);
|
|
3494
|
-
|
|
3495
3687
|
ggml_vec_scale_f32(ne00, y, scale);
|
|
3496
3688
|
}
|
|
3497
3689
|
}
|
|
@@ -4145,6 +4337,7 @@ void ggml_compute_forward_out_prod(
|
|
|
4145
4337
|
case GGML_TYPE_Q5_1:
|
|
4146
4338
|
case GGML_TYPE_Q8_0:
|
|
4147
4339
|
case GGML_TYPE_MXFP4:
|
|
4340
|
+
case GGML_TYPE_NVFP4:
|
|
4148
4341
|
case GGML_TYPE_Q2_K:
|
|
4149
4342
|
case GGML_TYPE_Q3_K:
|
|
4150
4343
|
case GGML_TYPE_Q4_K:
|
|
@@ -4420,6 +4613,7 @@ void ggml_compute_forward_set(
|
|
|
4420
4613
|
case GGML_TYPE_Q8_0:
|
|
4421
4614
|
case GGML_TYPE_Q8_1:
|
|
4422
4615
|
case GGML_TYPE_MXFP4:
|
|
4616
|
+
case GGML_TYPE_NVFP4:
|
|
4423
4617
|
case GGML_TYPE_Q2_K:
|
|
4424
4618
|
case GGML_TYPE_Q3_K:
|
|
4425
4619
|
case GGML_TYPE_Q4_K:
|
|
@@ -4459,46 +4653,6 @@ void ggml_compute_forward_cont(
|
|
|
4459
4653
|
ggml_compute_forward_dup(params, dst);
|
|
4460
4654
|
}
|
|
4461
4655
|
|
|
4462
|
-
// ggml_compute_forward_reshape
|
|
4463
|
-
|
|
4464
|
-
void ggml_compute_forward_reshape(
|
|
4465
|
-
const ggml_compute_params * params,
|
|
4466
|
-
ggml_tensor * dst) {
|
|
4467
|
-
// NOP
|
|
4468
|
-
GGML_UNUSED(params);
|
|
4469
|
-
GGML_UNUSED(dst);
|
|
4470
|
-
}
|
|
4471
|
-
|
|
4472
|
-
// ggml_compute_forward_view
|
|
4473
|
-
|
|
4474
|
-
void ggml_compute_forward_view(
|
|
4475
|
-
const ggml_compute_params * params,
|
|
4476
|
-
ggml_tensor * dst) {
|
|
4477
|
-
// NOP
|
|
4478
|
-
GGML_UNUSED(params);
|
|
4479
|
-
GGML_UNUSED(dst);
|
|
4480
|
-
}
|
|
4481
|
-
|
|
4482
|
-
// ggml_compute_forward_permute
|
|
4483
|
-
|
|
4484
|
-
void ggml_compute_forward_permute(
|
|
4485
|
-
const ggml_compute_params * params,
|
|
4486
|
-
ggml_tensor * dst) {
|
|
4487
|
-
// NOP
|
|
4488
|
-
GGML_UNUSED(params);
|
|
4489
|
-
GGML_UNUSED(dst);
|
|
4490
|
-
}
|
|
4491
|
-
|
|
4492
|
-
// ggml_compute_forward_transpose
|
|
4493
|
-
|
|
4494
|
-
void ggml_compute_forward_transpose(
|
|
4495
|
-
const ggml_compute_params * params,
|
|
4496
|
-
ggml_tensor * dst) {
|
|
4497
|
-
// NOP
|
|
4498
|
-
GGML_UNUSED(params);
|
|
4499
|
-
GGML_UNUSED(dst);
|
|
4500
|
-
}
|
|
4501
|
-
|
|
4502
4656
|
// ggml_compute_forward_get_rows
|
|
4503
4657
|
|
|
4504
4658
|
static void ggml_compute_forward_get_rows_q(
|
|
@@ -4682,6 +4836,7 @@ void ggml_compute_forward_get_rows(
|
|
|
4682
4836
|
case GGML_TYPE_Q8_0:
|
|
4683
4837
|
case GGML_TYPE_Q8_1:
|
|
4684
4838
|
case GGML_TYPE_MXFP4:
|
|
4839
|
+
case GGML_TYPE_NVFP4:
|
|
4685
4840
|
case GGML_TYPE_Q2_K:
|
|
4686
4841
|
case GGML_TYPE_Q3_K:
|
|
4687
4842
|
case GGML_TYPE_Q4_K:
|
|
@@ -5154,7 +5309,7 @@ static void ggml_compute_forward_soft_max_f32(
|
|
|
5154
5309
|
//printf("p[%d] = %f\n", i, p[i]);
|
|
5155
5310
|
assert(!isnan(wp[i]));
|
|
5156
5311
|
}
|
|
5157
|
-
#endif
|
|
5312
|
+
#endif // NDEBUG
|
|
5158
5313
|
|
|
5159
5314
|
float max = -INFINITY;
|
|
5160
5315
|
ggml_vec_max_f32(ne00, &max, wp);
|
|
@@ -5179,7 +5334,7 @@ static void ggml_compute_forward_soft_max_f32(
|
|
|
5179
5334
|
assert(!isnan(dp[i]));
|
|
5180
5335
|
assert(!isinf(dp[i]));
|
|
5181
5336
|
}
|
|
5182
|
-
#endif
|
|
5337
|
+
#endif // NDEBUG
|
|
5183
5338
|
}
|
|
5184
5339
|
}
|
|
5185
5340
|
}
|
|
@@ -5253,7 +5408,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32(
|
|
|
5253
5408
|
assert(!isnan(dy[i]));
|
|
5254
5409
|
assert(!isnan(y[i]));
|
|
5255
5410
|
}
|
|
5256
|
-
#endif
|
|
5411
|
+
#endif // NDEBUG
|
|
5257
5412
|
// Jii = yi - yi*yi
|
|
5258
5413
|
// Jij = -yi*yj
|
|
5259
5414
|
// J = diag(y)-y.T*y
|
|
@@ -5286,7 +5441,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32(
|
|
|
5286
5441
|
assert(!isnan(dx[i]));
|
|
5287
5442
|
assert(!isinf(dx[i]));
|
|
5288
5443
|
}
|
|
5289
|
-
#endif
|
|
5444
|
+
#endif // NDEBUG
|
|
5290
5445
|
}
|
|
5291
5446
|
}
|
|
5292
5447
|
|
|
@@ -5406,6 +5561,7 @@ void ggml_compute_forward_clamp(
|
|
|
5406
5561
|
case GGML_TYPE_Q8_0:
|
|
5407
5562
|
case GGML_TYPE_Q8_1:
|
|
5408
5563
|
case GGML_TYPE_MXFP4:
|
|
5564
|
+
case GGML_TYPE_NVFP4:
|
|
5409
5565
|
case GGML_TYPE_Q2_K:
|
|
5410
5566
|
case GGML_TYPE_Q3_K:
|
|
5411
5567
|
case GGML_TYPE_Q4_K:
|
|
@@ -5478,7 +5634,7 @@ static void ggml_rope_cache_init(
|
|
|
5478
5634
|
}
|
|
5479
5635
|
|
|
5480
5636
|
static void ggml_mrope_cache_init(
|
|
5481
|
-
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
|
|
5637
|
+
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
|
|
5482
5638
|
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
|
|
5483
5639
|
float * cache, float sin_sign, float theta_scale) {
|
|
5484
5640
|
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
|
@@ -5513,14 +5669,26 @@ static void ggml_mrope_cache_init(
|
|
|
5513
5669
|
}
|
|
5514
5670
|
|
|
5515
5671
|
float theta = theta_t;
|
|
5516
|
-
if (
|
|
5517
|
-
|
|
5518
|
-
|
|
5519
|
-
|
|
5520
|
-
|
|
5521
|
-
|
|
5522
|
-
|
|
5523
|
-
|
|
5672
|
+
if (is_imrope) { // qwen3vl apply interleaved mrope
|
|
5673
|
+
if (sector % 3 == 1 && sector < 3 * sections[1]) {
|
|
5674
|
+
theta = theta_h;
|
|
5675
|
+
} else if (sector % 3 == 2 && sector < 3 * sections[2]) {
|
|
5676
|
+
theta = theta_w;
|
|
5677
|
+
} else if (sector % 3 == 0 && sector < 3 * sections[0]) {
|
|
5678
|
+
theta = theta_t;
|
|
5679
|
+
} else {
|
|
5680
|
+
theta = theta_e;
|
|
5681
|
+
}
|
|
5682
|
+
} else {
|
|
5683
|
+
if (sector >= sections[0] && sector < sec_w) {
|
|
5684
|
+
theta = theta_h;
|
|
5685
|
+
}
|
|
5686
|
+
else if (sector >= sec_w && sector < sec_w + sections[2]) {
|
|
5687
|
+
theta = theta_w;
|
|
5688
|
+
}
|
|
5689
|
+
else if (sector >= sec_w + sections[2]) {
|
|
5690
|
+
theta = theta_e;
|
|
5691
|
+
}
|
|
5524
5692
|
}
|
|
5525
5693
|
|
|
5526
5694
|
rope_yarn(
|
|
@@ -5535,7 +5703,28 @@ static void ggml_mrope_cache_init(
|
|
|
5535
5703
|
}
|
|
5536
5704
|
}
|
|
5537
5705
|
|
|
5538
|
-
|
|
5706
|
+
|
|
5707
|
+
template<typename T>
|
|
5708
|
+
static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
|
|
5709
|
+
for (int64_t i0 = 0; i0 < n; i0 += 2) {
|
|
5710
|
+
const int64_t ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
|
|
5711
|
+
|
|
5712
|
+
const float cos_theta = cache[i0 + 0];
|
|
5713
|
+
const float sin_theta = cache[i0 + 1];
|
|
5714
|
+
|
|
5715
|
+
const T * const src = src_data + ic;
|
|
5716
|
+
T * dst = dst_data + ic;
|
|
5717
|
+
|
|
5718
|
+
const float x0 = type_conversion_table<T>::to_f32(src[0]);
|
|
5719
|
+
const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);
|
|
5720
|
+
|
|
5721
|
+
dst[0] = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
|
|
5722
|
+
dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
|
|
5723
|
+
}
|
|
5724
|
+
}
|
|
5725
|
+
|
|
5726
|
+
template<typename T> //float or ggml_fp16_t
|
|
5727
|
+
static void ggml_compute_forward_rope_flt(
|
|
5539
5728
|
const ggml_compute_params * params,
|
|
5540
5729
|
ggml_tensor * dst,
|
|
5541
5730
|
const bool forward) {
|
|
@@ -5544,6 +5733,9 @@ static void ggml_compute_forward_rope_f32(
|
|
|
5544
5733
|
const ggml_tensor * src1 = dst->src[1];
|
|
5545
5734
|
const ggml_tensor * src2 = dst->src[2];
|
|
5546
5735
|
|
|
5736
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
5737
|
+
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
|
5738
|
+
|
|
5547
5739
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
|
5548
5740
|
int sections[4];
|
|
5549
5741
|
|
|
@@ -5566,7 +5758,8 @@ static void ggml_compute_forward_rope_f32(
|
|
|
5566
5758
|
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
|
5567
5759
|
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
|
5568
5760
|
|
|
5569
|
-
GGML_ASSERT(
|
|
5761
|
+
GGML_ASSERT(nb0 == nb00);
|
|
5762
|
+
GGML_ASSERT(nb0 == sizeof(T));
|
|
5570
5763
|
|
|
5571
5764
|
const int ith = params->ith;
|
|
5572
5765
|
const int nth = params->nth;
|
|
@@ -5591,11 +5784,11 @@ static void ggml_compute_forward_rope_f32(
|
|
|
5591
5784
|
float corr_dims[2];
|
|
5592
5785
|
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
|
5593
5786
|
|
|
5594
|
-
const bool
|
|
5595
|
-
const bool
|
|
5787
|
+
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
|
|
5788
|
+
const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
|
|
5596
5789
|
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
|
5597
5790
|
|
|
5598
|
-
if (
|
|
5791
|
+
if (mrope_used) {
|
|
5599
5792
|
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
|
5600
5793
|
}
|
|
5601
5794
|
|
|
@@ -5617,290 +5810,63 @@ static void ggml_compute_forward_rope_f32(
|
|
|
5617
5810
|
|
|
5618
5811
|
const int32_t * pos = (const int32_t *) src1->data;
|
|
5619
5812
|
|
|
5813
|
+
int64_t last_i2 = -1;
|
|
5814
|
+
|
|
5620
5815
|
for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
|
|
5621
5816
|
for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
|
5817
|
+
for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
|
|
5818
|
+
if (ir++ < ir0) continue; // skip rows mapped to other threads
|
|
5819
|
+
if (ir > ir1) break;
|
|
5622
5820
|
|
|
5623
|
-
|
|
5624
|
-
|
|
5625
|
-
|
|
5626
|
-
|
|
5627
|
-
|
|
5628
|
-
else {
|
|
5629
|
-
const int64_t p_t = pos[i2];
|
|
5630
|
-
const int64_t p_h = pos[i2 + ne2];
|
|
5631
|
-
const int64_t p_w = pos[i2 + ne2 * 2];
|
|
5632
|
-
const int64_t p_e = pos[i2 + ne2 * 3];
|
|
5633
|
-
ggml_mrope_cache_init(
|
|
5634
|
-
p_t, p_h, p_w, p_e, sections, is_vision,
|
|
5635
|
-
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
5636
|
-
}
|
|
5637
|
-
|
|
5638
|
-
for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
|
|
5639
|
-
if (ir++ < ir0) continue;
|
|
5640
|
-
if (ir > ir1) break;
|
|
5641
|
-
|
|
5642
|
-
if (is_neox || is_mrope) {
|
|
5643
|
-
if (is_vision){
|
|
5644
|
-
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
5645
|
-
const int64_t ic = i0/2;
|
|
5646
|
-
|
|
5647
|
-
const float cos_theta = cache[i0 + 0];
|
|
5648
|
-
const float sin_theta = cache[i0 + 1];
|
|
5649
|
-
|
|
5650
|
-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
5651
|
-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
5652
|
-
|
|
5653
|
-
const float x0 = src[0];
|
|
5654
|
-
const float x1 = src[n_dims];
|
|
5655
|
-
|
|
5656
|
-
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
5657
|
-
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
|
|
5658
|
-
}
|
|
5659
|
-
} else {
|
|
5660
|
-
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
5661
|
-
const int64_t ic = i0/2;
|
|
5662
|
-
|
|
5663
|
-
const float cos_theta = cache[i0 + 0];
|
|
5664
|
-
const float sin_theta = cache[i0 + 1];
|
|
5665
|
-
|
|
5666
|
-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
5667
|
-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
5668
|
-
|
|
5669
|
-
const float x0 = src[0];
|
|
5670
|
-
const float x1 = src[n_dims/2];
|
|
5671
|
-
|
|
5672
|
-
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
5673
|
-
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
|
5674
|
-
}
|
|
5675
|
-
}
|
|
5676
|
-
} else {
|
|
5677
|
-
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
5678
|
-
const float cos_theta = cache[i0 + 0];
|
|
5679
|
-
const float sin_theta = cache[i0 + 1];
|
|
5680
|
-
|
|
5681
|
-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
5682
|
-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
5683
|
-
|
|
5684
|
-
const float x0 = src[0];
|
|
5685
|
-
const float x1 = src[1];
|
|
5686
|
-
|
|
5687
|
-
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
5688
|
-
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
|
5821
|
+
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
|
5822
|
+
if (last_i2 != i2) {
|
|
5823
|
+
if (!mrope_used) {
|
|
5824
|
+
const int64_t p = pos[i2];
|
|
5825
|
+
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
5689
5826
|
}
|
|
5690
|
-
|
|
5691
|
-
|
|
5692
|
-
|
|
5693
|
-
|
|
5694
|
-
const int64_t
|
|
5695
|
-
|
|
5696
|
-
|
|
5697
|
-
|
|
5698
|
-
|
|
5699
|
-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
5700
|
-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
5701
|
-
|
|
5702
|
-
const float x0 = src[0];
|
|
5703
|
-
const float x1 = src[n_dims];
|
|
5704
|
-
|
|
5705
|
-
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
5706
|
-
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
|
|
5827
|
+
else {
|
|
5828
|
+
const int64_t p_t = pos[i2];
|
|
5829
|
+
const int64_t p_h = pos[i2 + ne2];
|
|
5830
|
+
const int64_t p_w = pos[i2 + ne2 * 2];
|
|
5831
|
+
const int64_t p_e = pos[i2 + ne2 * 3];
|
|
5832
|
+
ggml_mrope_cache_init(
|
|
5833
|
+
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
|
|
5834
|
+
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
5707
5835
|
}
|
|
5708
|
-
} else {
|
|
5709
|
-
// fill the remain channels with data from src tensor
|
|
5710
|
-
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
|
5711
|
-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
5712
|
-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
5713
5836
|
|
|
5714
|
-
|
|
5715
|
-
dst_data[1] = src[1];
|
|
5716
|
-
}
|
|
5837
|
+
last_i2 = i2;
|
|
5717
5838
|
}
|
|
5718
|
-
}
|
|
5719
|
-
}
|
|
5720
|
-
}
|
|
5721
|
-
}
|
|
5722
|
-
|
|
5723
|
-
// TODO: deduplicate f16/f32 code
|
|
5724
|
-
static void ggml_compute_forward_rope_f16(
|
|
5725
|
-
const ggml_compute_params * params,
|
|
5726
|
-
ggml_tensor * dst,
|
|
5727
|
-
const bool forward) {
|
|
5728
|
-
|
|
5729
|
-
const ggml_tensor * src0 = dst->src[0];
|
|
5730
|
-
const ggml_tensor * src1 = dst->src[1];
|
|
5731
|
-
const ggml_tensor * src2 = dst->src[2];
|
|
5732
|
-
|
|
5733
|
-
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
|
5734
|
-
int sections[4];
|
|
5735
|
-
|
|
5736
|
-
//const int n_past = ((int32_t *) dst->op_params)[0];
|
|
5737
|
-
const int n_dims = ((int32_t *) dst->op_params)[1];
|
|
5738
|
-
const int mode = ((int32_t *) dst->op_params)[2];
|
|
5739
|
-
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
|
5740
|
-
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
|
5741
|
-
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
|
5742
|
-
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
|
5743
|
-
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
|
5744
|
-
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
|
5745
|
-
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
|
5746
|
-
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
|
5747
|
-
memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
|
5748
|
-
|
|
5749
|
-
|
|
5750
|
-
GGML_TENSOR_UNARY_OP_LOCALS
|
|
5751
|
-
|
|
5752
|
-
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
|
5753
|
-
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
|
5754
|
-
|
|
5755
|
-
GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
|
|
5756
|
-
|
|
5757
|
-
const int ith = params->ith;
|
|
5758
|
-
const int nth = params->nth;
|
|
5759
|
-
|
|
5760
|
-
const int nr = ggml_nrows(dst);
|
|
5761
|
-
|
|
5762
|
-
GGML_ASSERT(n_dims <= ne0);
|
|
5763
|
-
GGML_ASSERT(n_dims % 2 == 0);
|
|
5764
|
-
|
|
5765
|
-
// rows per thread
|
|
5766
|
-
const int dr = (nr + nth - 1)/nth;
|
|
5767
|
-
|
|
5768
|
-
// row range for this thread
|
|
5769
|
-
const int ir0 = dr*ith;
|
|
5770
|
-
const int ir1 = MIN(ir0 + dr, nr);
|
|
5771
|
-
|
|
5772
|
-
// row index used to determine which thread to use
|
|
5773
|
-
int ir = 0;
|
|
5774
|
-
|
|
5775
|
-
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
|
5776
|
-
|
|
5777
|
-
float corr_dims[2];
|
|
5778
|
-
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
|
5779
|
-
|
|
5780
|
-
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
|
5781
|
-
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
|
5782
|
-
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
|
5783
|
-
|
|
5784
|
-
if (is_mrope) {
|
|
5785
|
-
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
|
5786
|
-
}
|
|
5787
|
-
|
|
5788
|
-
if (is_vision) {
|
|
5789
|
-
GGML_ASSERT(n_dims == ne0/2);
|
|
5790
|
-
}
|
|
5791
|
-
|
|
5792
|
-
const float * freq_factors = NULL;
|
|
5793
|
-
if (src2 != NULL) {
|
|
5794
|
-
GGML_ASSERT(src2->type == GGML_TYPE_F32);
|
|
5795
|
-
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
|
|
5796
|
-
freq_factors = (const float *) src2->data;
|
|
5797
|
-
}
|
|
5798
|
-
|
|
5799
|
-
// backward process uses inverse rotation by cos and sin.
|
|
5800
|
-
// cos and sin build a rotation matrix, where the inverse is the transpose.
|
|
5801
|
-
// this essentially just switches the sign of sin.
|
|
5802
|
-
const float sin_sign = forward ? 1.0f : -1.0f;
|
|
5803
|
-
|
|
5804
|
-
const int32_t * pos = (const int32_t *) src1->data;
|
|
5805
|
-
|
|
5806
|
-
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
5807
|
-
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
|
5808
|
-
|
|
5809
|
-
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
|
5810
|
-
if (!is_mrope) {
|
|
5811
|
-
const int64_t p = pos[i2];
|
|
5812
|
-
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
5813
|
-
}
|
|
5814
|
-
else {
|
|
5815
|
-
const int64_t p_t = pos[i2];
|
|
5816
|
-
const int64_t p_h = pos[i2 + ne2];
|
|
5817
|
-
const int64_t p_w = pos[i2 + ne2 * 2];
|
|
5818
|
-
const int64_t p_e = pos[i2 + ne2 * 3];
|
|
5819
|
-
ggml_mrope_cache_init(
|
|
5820
|
-
p_t, p_h, p_w, p_e, sections, is_vision,
|
|
5821
|
-
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
5822
|
-
}
|
|
5823
|
-
|
|
5824
|
-
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
5825
|
-
if (ir++ < ir0) continue;
|
|
5826
|
-
if (ir > ir1) break;
|
|
5827
|
-
|
|
5828
|
-
if (is_neox || is_mrope) {
|
|
5829
|
-
if (is_vision) {
|
|
5830
|
-
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
5831
|
-
const int64_t ic = i0/2;
|
|
5832
|
-
|
|
5833
|
-
const float cos_theta = cache[i0 + 0];
|
|
5834
|
-
const float sin_theta = cache[i0 + 1];
|
|
5835
|
-
|
|
5836
|
-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
5837
|
-
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
5838
|
-
|
|
5839
|
-
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
|
5840
|
-
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
|
|
5841
5839
|
|
|
5842
|
-
|
|
5843
|
-
|
|
5844
|
-
|
|
5845
|
-
|
|
5846
|
-
|
|
5847
|
-
|
|
5848
|
-
|
|
5849
|
-
|
|
5850
|
-
|
|
5851
|
-
|
|
5852
|
-
|
|
5853
|
-
|
|
5854
|
-
|
|
5855
|
-
|
|
5856
|
-
|
|
5857
|
-
|
|
5858
|
-
|
|
5859
|
-
dst_data[n_dims/2] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
5860
|
-
}
|
|
5861
|
-
}
|
|
5862
|
-
} else {
|
|
5863
|
-
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
5864
|
-
const float cos_theta = cache[i0 + 0];
|
|
5865
|
-
const float sin_theta = cache[i0 + 1];
|
|
5866
|
-
|
|
5867
|
-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
5868
|
-
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
5869
|
-
|
|
5870
|
-
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
|
5871
|
-
const float x1 = GGML_CPU_FP16_TO_FP32(src[1]);
|
|
5872
|
-
|
|
5873
|
-
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
5874
|
-
dst_data[1] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
5875
|
-
}
|
|
5840
|
+
T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
|
|
5841
|
+
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
|
|
5842
|
+
|
|
5843
|
+
switch (mode) {
|
|
5844
|
+
case GGML_ROPE_TYPE_NORMAL:
|
|
5845
|
+
rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
|
|
5846
|
+
break;
|
|
5847
|
+
case GGML_ROPE_TYPE_NEOX:
|
|
5848
|
+
case GGML_ROPE_TYPE_MROPE:
|
|
5849
|
+
case GGML_ROPE_TYPE_IMROPE:
|
|
5850
|
+
rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
|
|
5851
|
+
break;
|
|
5852
|
+
case GGML_ROPE_TYPE_VISION:
|
|
5853
|
+
rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
|
|
5854
|
+
break;
|
|
5855
|
+
default:
|
|
5856
|
+
GGML_ABORT("rope type not supported");
|
|
5876
5857
|
}
|
|
5877
5858
|
|
|
5878
|
-
if (is_vision) {
|
|
5879
|
-
|
|
5880
|
-
const int64_t ic = i0/2;
|
|
5881
|
-
|
|
5882
|
-
const float cos_theta = cache[i0 + 0];
|
|
5883
|
-
const float sin_theta = cache[i0 + 1];
|
|
5884
|
-
|
|
5885
|
-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
5886
|
-
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
5887
|
-
|
|
5888
|
-
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
|
5889
|
-
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
|
|
5890
|
-
|
|
5891
|
-
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
5892
|
-
dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
5893
|
-
}
|
|
5894
|
-
} else {
|
|
5859
|
+
if (!is_vision) {
|
|
5860
|
+
// fill the remain channels with data from src tensor
|
|
5895
5861
|
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
|
5896
|
-
const
|
|
5897
|
-
|
|
5862
|
+
const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
5863
|
+
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
5898
5864
|
|
|
5899
5865
|
dst_data[0] = src[0];
|
|
5900
5866
|
dst_data[1] = src[1];
|
|
5901
5867
|
}
|
|
5902
5868
|
}
|
|
5903
|
-
}
|
|
5869
|
+
} //attn-heads
|
|
5904
5870
|
}
|
|
5905
5871
|
}
|
|
5906
5872
|
}
|
|
@@ -5914,11 +5880,11 @@ void ggml_compute_forward_rope(
|
|
|
5914
5880
|
switch (src0->type) {
|
|
5915
5881
|
case GGML_TYPE_F16:
|
|
5916
5882
|
{
|
|
5917
|
-
|
|
5883
|
+
ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);
|
|
5918
5884
|
} break;
|
|
5919
5885
|
case GGML_TYPE_F32:
|
|
5920
5886
|
{
|
|
5921
|
-
|
|
5887
|
+
ggml_compute_forward_rope_flt<float>(params, dst, true);
|
|
5922
5888
|
} break;
|
|
5923
5889
|
default:
|
|
5924
5890
|
{
|
|
@@ -5938,11 +5904,11 @@ void ggml_compute_forward_rope_back(
|
|
|
5938
5904
|
switch (src0->type) {
|
|
5939
5905
|
case GGML_TYPE_F16:
|
|
5940
5906
|
{
|
|
5941
|
-
|
|
5907
|
+
ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);
|
|
5942
5908
|
} break;
|
|
5943
5909
|
case GGML_TYPE_F32:
|
|
5944
5910
|
{
|
|
5945
|
-
|
|
5911
|
+
ggml_compute_forward_rope_flt<float>(params, dst, false);
|
|
5946
5912
|
} break;
|
|
5947
5913
|
default:
|
|
5948
5914
|
{
|
|
@@ -6239,7 +6205,7 @@ static void ggml_compute_forward_im2col_f16(
|
|
|
6239
6205
|
const ggml_tensor * src1 = dst->src[1];
|
|
6240
6206
|
|
|
6241
6207
|
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
|
6242
|
-
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
6208
|
+
GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
|
|
6243
6209
|
GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
|
6244
6210
|
|
|
6245
6211
|
GGML_TENSOR_BINARY_OP_LOCALS;
|
|
@@ -6270,7 +6236,7 @@ static void ggml_compute_forward_im2col_f16(
|
|
|
6270
6236
|
int ofs1 = is_2D ? nb12 : nb11;
|
|
6271
6237
|
|
|
6272
6238
|
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
|
6273
|
-
GGML_ASSERT(nb10 ==
|
|
6239
|
+
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
|
|
6274
6240
|
|
|
6275
6241
|
// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
|
|
6276
6242
|
{
|
|
@@ -6283,7 +6249,12 @@ static void ggml_compute_forward_im2col_f16(
|
|
|
6283
6249
|
|
|
6284
6250
|
// micro kernel
|
|
6285
6251
|
ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
|
|
6286
|
-
const float * const
|
|
6252
|
+
const float * const src_data_f32 = src1->type == GGML_TYPE_F32
|
|
6253
|
+
? (const float *)((const char *) src1->data + in*ofs0 + iic*ofs1)
|
|
6254
|
+
: nullptr; // [IH, IW]
|
|
6255
|
+
const ggml_fp16_t * const src_data_f16 = src1->type == GGML_TYPE_F16
|
|
6256
|
+
? (const ggml_fp16_t *)((const char *) src1->data + in*ofs0 + iic*ofs1)
|
|
6257
|
+
: nullptr; // [IH, IW]
|
|
6287
6258
|
|
|
6288
6259
|
for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
|
|
6289
6260
|
for (int64_t ikw = 0; ikw < KW; ikw++) {
|
|
@@ -6293,7 +6264,11 @@ static void ggml_compute_forward_im2col_f16(
|
|
|
6293
6264
|
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
|
6294
6265
|
dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
|
|
6295
6266
|
} else {
|
|
6296
|
-
|
|
6267
|
+
if (src_data_f32 != nullptr) {
|
|
6268
|
+
dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data_f32[iih*IW + iiw]);
|
|
6269
|
+
} else {
|
|
6270
|
+
dst_data[iic*(KH*KW) + ikh*KW + ikw] = src_data_f16[iih*IW + iiw];
|
|
6271
|
+
}
|
|
6297
6272
|
}
|
|
6298
6273
|
}
|
|
6299
6274
|
}
|
|
@@ -6493,7 +6468,7 @@ static void ggml_compute_forward_im2col_3d_f16(
|
|
|
6493
6468
|
const int64_t iih = ioh*s1 + ikh*d1 - p1;
|
|
6494
6469
|
const int64_t iid = iod*s2 + ikd*d2 - p2;
|
|
6495
6470
|
|
|
6496
|
-
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW
|
|
6471
|
+
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
|
6497
6472
|
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
|
|
6498
6473
|
} else {
|
|
6499
6474
|
const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
|
|
@@ -6664,8 +6639,13 @@ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params
|
|
|
6664
6639
|
ggml_compute_forward_mul_mat(params, &dst);
|
|
6665
6640
|
}
|
|
6666
6641
|
|
|
6642
|
+
static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
|
|
6643
|
+
return (coord + size) % size; // adding size avoids negative number weirdness
|
|
6644
|
+
}
|
|
6645
|
+
|
|
6667
6646
|
// ggml_compute_forward_conv_2d
|
|
6668
6647
|
|
|
6648
|
+
|
|
6669
6649
|
static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
|
|
6670
6650
|
const ggml_tensor * kernel, // [KW, KH, IC, OC]
|
|
6671
6651
|
const ggml_tensor * src, // [W, H, C, N]
|
|
@@ -7074,7 +7054,11 @@ static void ggml_compute_forward_conv_2d_dw_cwhn(
|
|
|
7074
7054
|
const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
|
|
7075
7055
|
|
|
7076
7056
|
#ifdef GGML_SIMD
|
|
7077
|
-
|
|
7057
|
+
#if defined(__ARM_FEATURE_SVE)
|
|
7058
|
+
const int64_t pkg_size = svcntw();
|
|
7059
|
+
#else
|
|
7060
|
+
const int64_t pkg_size = GGML_F32_EPR;
|
|
7061
|
+
#endif
|
|
7078
7062
|
const int64_t pkg_count = c / pkg_size;
|
|
7079
7063
|
const int64_t c_pkg_end = pkg_count * pkg_size;
|
|
7080
7064
|
#else
|
|
@@ -7211,12 +7195,13 @@ void ggml_compute_forward_conv_2d_dw(
|
|
|
7211
7195
|
}
|
|
7212
7196
|
}
|
|
7213
7197
|
|
|
7214
|
-
//
|
|
7215
|
-
|
|
7216
|
-
static void ggml_compute_forward_pool_1d_sk_p0(
|
|
7198
|
+
// ggml_compute_forward_pool_1d_ksp
|
|
7199
|
+
static void ggml_compute_forward_pool_1d_ksp(
|
|
7217
7200
|
const ggml_compute_params * params,
|
|
7218
7201
|
const ggml_op_pool op,
|
|
7219
7202
|
const int k,
|
|
7203
|
+
const int s,
|
|
7204
|
+
const int p,
|
|
7220
7205
|
ggml_tensor * dst) {
|
|
7221
7206
|
|
|
7222
7207
|
const ggml_tensor * src = dst->src[0];
|
|
@@ -7227,39 +7212,56 @@ static void ggml_compute_forward_pool_1d_sk_p0(
|
|
|
7227
7212
|
return;
|
|
7228
7213
|
}
|
|
7229
7214
|
|
|
7230
|
-
const
|
|
7231
|
-
const
|
|
7232
|
-
float * drow = (float *)dst->data;
|
|
7215
|
+
const int64_t IW = src->ne[0];
|
|
7216
|
+
const int64_t OW = dst->ne[0];
|
|
7233
7217
|
|
|
7234
|
-
const int64_t
|
|
7218
|
+
const int64_t nr = ggml_nrows(src);
|
|
7235
7219
|
|
|
7236
|
-
|
|
7237
|
-
const
|
|
7238
|
-
|
|
7239
|
-
|
|
7220
|
+
for (int64_t ir = 0; ir < nr; ++ir) {
|
|
7221
|
+
const char * srow_bytes = (const char *) src->data + ir * src->nb[1];
|
|
7222
|
+
float * drow = (float *) (( char *) dst->data + ir * dst->nb[1]);
|
|
7223
|
+
|
|
7224
|
+
for (int64_t ow = 0; ow < OW; ++ow) {
|
|
7225
|
+
float res = 0;
|
|
7240
7226
|
switch (op) {
|
|
7241
|
-
case GGML_OP_POOL_AVG:
|
|
7242
|
-
case GGML_OP_POOL_MAX:
|
|
7227
|
+
case GGML_OP_POOL_AVG: res = 0.0f; break;
|
|
7228
|
+
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
|
|
7243
7229
|
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
|
7244
7230
|
}
|
|
7231
|
+
|
|
7232
|
+
int count = 0;
|
|
7233
|
+
const int base = (int) ow * s - p;
|
|
7234
|
+
|
|
7245
7235
|
for (int ki = 0; ki < k; ++ki) {
|
|
7246
|
-
const
|
|
7236
|
+
const int j = base + ki;
|
|
7237
|
+
if (j < 0 || j >= (int) IW) {
|
|
7238
|
+
continue;
|
|
7239
|
+
}
|
|
7240
|
+
|
|
7241
|
+
float v;
|
|
7242
|
+
if (src->type == GGML_TYPE_F32) {
|
|
7243
|
+
v = ((const float *) srow_bytes)[j];
|
|
7244
|
+
} else {
|
|
7245
|
+
v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]);
|
|
7246
|
+
}
|
|
7247
|
+
|
|
7247
7248
|
switch (op) {
|
|
7248
|
-
case GGML_OP_POOL_AVG:
|
|
7249
|
-
case GGML_OP_POOL_MAX:
|
|
7250
|
-
case GGML_OP_POOL_COUNT:
|
|
7249
|
+
case GGML_OP_POOL_AVG: res += v; break;
|
|
7250
|
+
case GGML_OP_POOL_MAX: res = std::max(v, res); break;
|
|
7251
|
+
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
|
7251
7252
|
}
|
|
7252
|
-
|
|
7253
|
+
|
|
7254
|
+
++count;
|
|
7253
7255
|
}
|
|
7256
|
+
|
|
7254
7257
|
switch (op) {
|
|
7255
|
-
case GGML_OP_POOL_AVG:
|
|
7256
|
-
case GGML_OP_POOL_MAX:
|
|
7258
|
+
case GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break;
|
|
7259
|
+
case GGML_OP_POOL_MAX: break;
|
|
7257
7260
|
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
|
7258
7261
|
}
|
|
7259
|
-
}
|
|
7260
7262
|
|
|
7261
|
-
|
|
7262
|
-
|
|
7263
|
+
drow[ow] = res;
|
|
7264
|
+
}
|
|
7263
7265
|
}
|
|
7264
7266
|
}
|
|
7265
7267
|
|
|
@@ -7274,10 +7276,8 @@ void ggml_compute_forward_pool_1d(
|
|
|
7274
7276
|
const int k0 = opts[1];
|
|
7275
7277
|
const int s0 = opts[2];
|
|
7276
7278
|
const int p0 = opts[3];
|
|
7277
|
-
GGML_ASSERT(p0 == 0); // padding not supported
|
|
7278
|
-
GGML_ASSERT(k0 == s0); // only s = k supported
|
|
7279
7279
|
|
|
7280
|
-
|
|
7280
|
+
ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst);
|
|
7281
7281
|
}
|
|
7282
7282
|
|
|
7283
7283
|
// ggml_compute_forward_pool_2d
|
|
@@ -7295,6 +7295,7 @@ void ggml_compute_forward_pool_2d(
|
|
|
7295
7295
|
}
|
|
7296
7296
|
|
|
7297
7297
|
const int32_t * opts = (const int32_t *)dst->op_params;
|
|
7298
|
+
|
|
7298
7299
|
ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
|
|
7299
7300
|
const int k0 = opts[1];
|
|
7300
7301
|
const int k1 = opts[2];
|
|
@@ -7318,11 +7319,13 @@ void ggml_compute_forward_pool_2d(
|
|
|
7318
7319
|
while (cdata < data_end) {
|
|
7319
7320
|
for (int oy = 0; oy < py; ++oy) {
|
|
7320
7321
|
float * const drow = dplane + oy * px;
|
|
7322
|
+
float * const out = drow;
|
|
7323
|
+
|
|
7321
7324
|
for (int ox = 0; ox < px; ++ox) {
|
|
7322
|
-
float
|
|
7325
|
+
float res = 0;
|
|
7323
7326
|
switch (op) {
|
|
7324
|
-
case GGML_OP_POOL_AVG:
|
|
7325
|
-
case GGML_OP_POOL_MAX:
|
|
7327
|
+
case GGML_OP_POOL_AVG: res = 0; break;
|
|
7328
|
+
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
|
|
7326
7329
|
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
|
7327
7330
|
}
|
|
7328
7331
|
|
|
@@ -7330,24 +7333,32 @@ void ggml_compute_forward_pool_2d(
|
|
|
7330
7333
|
const int iy = offset1 + oy * s1;
|
|
7331
7334
|
|
|
7332
7335
|
for (int ky = 0; ky < k1; ++ky) {
|
|
7333
|
-
if (iy + ky < 0 || iy + ky >= src->ne[1])
|
|
7336
|
+
if (iy + ky < 0 || iy + ky >= src->ne[1]) {
|
|
7337
|
+
continue;
|
|
7338
|
+
}
|
|
7339
|
+
|
|
7334
7340
|
const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
|
|
7335
7341
|
for (int kx = 0; kx < k0; ++kx) {
|
|
7336
7342
|
int j = ix + kx;
|
|
7337
|
-
if (j < 0 || j >= src->ne[0])
|
|
7343
|
+
if (j < 0 || j >= src->ne[0]) {
|
|
7344
|
+
continue;
|
|
7345
|
+
}
|
|
7346
|
+
|
|
7338
7347
|
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
|
|
7339
7348
|
switch (op) {
|
|
7340
|
-
case GGML_OP_POOL_AVG:
|
|
7341
|
-
case GGML_OP_POOL_MAX:
|
|
7349
|
+
case GGML_OP_POOL_AVG: res += srow_j; break;
|
|
7350
|
+
case GGML_OP_POOL_MAX: res = std::max(srow_j, res); break;
|
|
7342
7351
|
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
|
7343
7352
|
}
|
|
7344
7353
|
}
|
|
7345
7354
|
}
|
|
7346
7355
|
switch (op) {
|
|
7347
|
-
case GGML_OP_POOL_AVG:
|
|
7348
|
-
case GGML_OP_POOL_MAX:
|
|
7356
|
+
case GGML_OP_POOL_AVG: res /= ka; break;
|
|
7357
|
+
case GGML_OP_POOL_MAX: break;
|
|
7349
7358
|
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
|
7350
7359
|
}
|
|
7360
|
+
|
|
7361
|
+
out[ox] = res;
|
|
7351
7362
|
}
|
|
7352
7363
|
}
|
|
7353
7364
|
|
|
@@ -7497,10 +7508,17 @@ static void ggml_compute_forward_upscale_f32(
|
|
|
7497
7508
|
float sf1 = (float)ne1/src0->ne[1];
|
|
7498
7509
|
float sf2 = (float)ne2/src0->ne[2];
|
|
7499
7510
|
float sf3 = (float)ne3/src0->ne[3];
|
|
7511
|
+
float pixel_offset = 0.5f;
|
|
7500
7512
|
|
|
7501
7513
|
const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
|
|
7502
7514
|
const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
|
|
7503
7515
|
|
|
7516
|
+
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
|
7517
|
+
pixel_offset = 0.0f;
|
|
7518
|
+
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
|
|
7519
|
+
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
|
|
7520
|
+
}
|
|
7521
|
+
|
|
7504
7522
|
if (mode == GGML_SCALE_MODE_NEAREST) {
|
|
7505
7523
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
7506
7524
|
const int64_t i03 = i3 / sf3;
|
|
@@ -7519,14 +7537,66 @@ static void ggml_compute_forward_upscale_f32(
|
|
|
7519
7537
|
}
|
|
7520
7538
|
}
|
|
7521
7539
|
}
|
|
7522
|
-
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
|
|
7523
|
-
|
|
7524
|
-
|
|
7525
|
-
|
|
7526
|
-
|
|
7527
|
-
|
|
7528
|
-
|
|
7540
|
+
} else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) {
|
|
7541
|
+
// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
|
|
7542
|
+
// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
|
|
7543
|
+
auto triangle_filter = [](float x) -> float {
|
|
7544
|
+
return std::max(1.0f - fabsf(x), 0.0f);
|
|
7545
|
+
};
|
|
7546
|
+
|
|
7547
|
+
// support and invscale, minimum 1 pixel for bilinear
|
|
7548
|
+
const float support1 = std::max(1.0f, 1.0f / sf1);
|
|
7549
|
+
const float invscale1 = 1.0f / support1;
|
|
7550
|
+
const float support0 = std::max(1.0f, 1.0f / sf0);
|
|
7551
|
+
const float invscale0 = 1.0f / support0;
|
|
7552
|
+
|
|
7553
|
+
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
7554
|
+
const int64_t i03 = i3 / sf3;
|
|
7555
|
+
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
|
|
7556
|
+
const int64_t i02 = i2 / sf2;
|
|
7557
|
+
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
7558
|
+
const float y = ((float) i1 + pixel_offset) / sf1;
|
|
7559
|
+
for (int64_t i0 = 0; i0 < ne0; i0++) {
|
|
7560
|
+
const float x = ((float) i0 + pixel_offset) / sf0;
|
|
7561
|
+
|
|
7562
|
+
// the range of source pixels that contribute
|
|
7563
|
+
const int64_t x_min = std::max<int64_t>(x - support0 + pixel_offset, 0);
|
|
7564
|
+
const int64_t x_max = std::min<int64_t>(x + support0 + pixel_offset, ne00);
|
|
7565
|
+
const int64_t y_min = std::max<int64_t>(y - support1 + pixel_offset, 0);
|
|
7566
|
+
const int64_t y_max = std::min<int64_t>(y + support1 + pixel_offset, ne01);
|
|
7567
|
+
|
|
7568
|
+
// bilinear filter with antialiasing
|
|
7569
|
+
float val = 0.0f;
|
|
7570
|
+
float total_weight = 0.0f;
|
|
7571
|
+
|
|
7572
|
+
for (int64_t sy = y_min; sy < y_max; sy++) {
|
|
7573
|
+
const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
|
|
7574
|
+
|
|
7575
|
+
for (int64_t sx = x_min; sx < x_max; sx++) {
|
|
7576
|
+
const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
|
|
7577
|
+
const float weight = weight_x * weight_y;
|
|
7578
|
+
|
|
7579
|
+
if (weight <= 0.0f) {
|
|
7580
|
+
continue;
|
|
7581
|
+
}
|
|
7582
|
+
|
|
7583
|
+
const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03);
|
|
7584
|
+
val += pixel * weight;
|
|
7585
|
+
total_weight += weight;
|
|
7586
|
+
}
|
|
7587
|
+
}
|
|
7529
7588
|
|
|
7589
|
+
if (total_weight > 0.0f) {
|
|
7590
|
+
val /= total_weight;
|
|
7591
|
+
}
|
|
7592
|
+
|
|
7593
|
+
float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
|
|
7594
|
+
*dst_ptr = val;
|
|
7595
|
+
}
|
|
7596
|
+
}
|
|
7597
|
+
}
|
|
7598
|
+
}
|
|
7599
|
+
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
|
|
7530
7600
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
7531
7601
|
const int64_t i03 = i3 / sf3;
|
|
7532
7602
|
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
|
|
@@ -7561,6 +7631,51 @@ static void ggml_compute_forward_upscale_f32(
|
|
|
7561
7631
|
|
|
7562
7632
|
const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
|
|
7563
7633
|
|
|
7634
|
+
float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
|
|
7635
|
+
*y_dst = val;
|
|
7636
|
+
}
|
|
7637
|
+
}
|
|
7638
|
+
}
|
|
7639
|
+
}
|
|
7640
|
+
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
|
|
7641
|
+
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
|
|
7642
|
+
const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
|
|
7643
|
+
auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
|
|
7644
|
+
auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
|
|
7645
|
+
auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
|
|
7646
|
+
const float w0 = weight2(x + 1);
|
|
7647
|
+
const float w1 = weight1(x + 0);
|
|
7648
|
+
const float w2 = weight1(1 - x);
|
|
7649
|
+
const float w3 = weight2(2 - x);
|
|
7650
|
+
return p0*w0 + p1*w1 + p2*w2 + p3*w3;
|
|
7651
|
+
};
|
|
7652
|
+
|
|
7653
|
+
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
7654
|
+
const int64_t i03 = i3 / sf3;
|
|
7655
|
+
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
|
|
7656
|
+
const int64_t i02 = i2 / sf2;
|
|
7657
|
+
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
7658
|
+
const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
|
|
7659
|
+
const int64_t y0 = (int64_t)floorf(y);
|
|
7660
|
+
const float dy = y - (float)y0;
|
|
7661
|
+
|
|
7662
|
+
for (int64_t i0 = 0; i0 < ne0; i0++) {
|
|
7663
|
+
const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
|
|
7664
|
+
const int64_t x0 = (int64_t)floorf(x);
|
|
7665
|
+
const float dx = x - (float)x0;
|
|
7666
|
+
|
|
7667
|
+
auto p = [=](int64_t x_off, int64_t y_off) -> float {
|
|
7668
|
+
int64_t i00 = std::max(int64_t(0), std::min(x0 + x_off, ne00 - 1));
|
|
7669
|
+
int64_t i01 = std::max(int64_t(0), std::min(y0 + y_off, ne01 - 1));
|
|
7670
|
+
return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
7671
|
+
};
|
|
7672
|
+
|
|
7673
|
+
const float val = bicubic(
|
|
7674
|
+
bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),
|
|
7675
|
+
bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),
|
|
7676
|
+
bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),
|
|
7677
|
+
bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);
|
|
7678
|
+
|
|
7564
7679
|
float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
|
|
7565
7680
|
*y_dst = val;
|
|
7566
7681
|
}
|
|
@@ -7593,14 +7708,14 @@ void ggml_compute_forward_upscale(
|
|
|
7593
7708
|
|
|
7594
7709
|
// ggml_compute_forward_pad
|
|
7595
7710
|
|
|
7711
|
+
template<bool circular_t>
|
|
7596
7712
|
static void ggml_compute_forward_pad_f32(
|
|
7597
7713
|
const ggml_compute_params * params,
|
|
7598
7714
|
ggml_tensor * dst) {
|
|
7599
7715
|
|
|
7600
7716
|
const ggml_tensor * src0 = dst->src[0];
|
|
7601
7717
|
|
|
7602
|
-
|
|
7603
|
-
GGML_ASSERT( dst->nb[0] == sizeof(float));
|
|
7718
|
+
assert(dst->nb[0] == sizeof(float));
|
|
7604
7719
|
|
|
7605
7720
|
const int ith = params->ith;
|
|
7606
7721
|
const int nth = params->nth;
|
|
@@ -7617,23 +7732,40 @@ static void ggml_compute_forward_pad_f32(
|
|
|
7617
7732
|
const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
|
|
7618
7733
|
const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
|
|
7619
7734
|
|
|
7620
|
-
|
|
7621
7735
|
// TODO: optimize
|
|
7622
7736
|
|
|
7623
7737
|
for (int64_t i2 = 0; i2 < ne2; ++i2) {
|
|
7624
7738
|
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
|
|
7625
7739
|
for (int64_t i0 = 0; i0 < ne0; ++i0) {
|
|
7626
7740
|
for (int64_t i3 = 0; i3 < ne3; ++i3) {
|
|
7627
|
-
|
|
7628
|
-
if (
|
|
7629
|
-
|
|
7630
|
-
|
|
7631
|
-
|
|
7632
|
-
const int64_t
|
|
7741
|
+
// circular means wrap around on a torus, so x and y loop around
|
|
7742
|
+
if constexpr (circular_t) {
|
|
7743
|
+
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
|
7744
|
+
const int64_t src_i0 = ggml_wrap_around(i0 - lp0, ne00);
|
|
7745
|
+
const int64_t src_i1 = ggml_wrap_around(i1 - lp1, ne01);
|
|
7746
|
+
const int64_t src_i2 = ggml_wrap_around(i2 - lp2, ne02);
|
|
7747
|
+
const int64_t src_i3 = ggml_wrap_around(i3 - lp3, ne03);
|
|
7748
|
+
|
|
7749
|
+
const int64_t src_idx =
|
|
7750
|
+
src_i3*nb03 +
|
|
7751
|
+
src_i2*nb02 +
|
|
7752
|
+
src_i1*nb01 +
|
|
7753
|
+
src_i0*nb00;
|
|
7754
|
+
|
|
7633
7755
|
const float * src_ptr = (const float *)((char *) src0->data + src_idx);
|
|
7634
7756
|
dst_ptr[dst_idx] = *src_ptr;
|
|
7635
7757
|
} else {
|
|
7636
|
-
|
|
7758
|
+
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
|
7759
|
+
if ((i0 >= lp0 && i0 < ne0 - rp0) \
|
|
7760
|
+
&& (i1 >= lp1 && i1 < ne1 - rp1) \
|
|
7761
|
+
&& (i2 >= lp2 && i2 < ne2 - rp2) \
|
|
7762
|
+
&& (i3 >= lp3 && i3 < ne3 - rp3)) {
|
|
7763
|
+
const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
|
|
7764
|
+
const float * src_ptr = (const float *)((char *) src0->data + src_idx);
|
|
7765
|
+
dst_ptr[dst_idx] = *src_ptr;
|
|
7766
|
+
} else {
|
|
7767
|
+
dst_ptr[dst_idx] = 0;
|
|
7768
|
+
}
|
|
7637
7769
|
}
|
|
7638
7770
|
}
|
|
7639
7771
|
}
|
|
@@ -7641,16 +7773,20 @@ static void ggml_compute_forward_pad_f32(
|
|
|
7641
7773
|
}
|
|
7642
7774
|
}
|
|
7643
7775
|
|
|
7776
|
+
|
|
7644
7777
|
void ggml_compute_forward_pad(
|
|
7645
7778
|
const ggml_compute_params * params,
|
|
7646
7779
|
ggml_tensor * dst) {
|
|
7647
|
-
|
|
7648
7780
|
const ggml_tensor * src0 = dst->src[0];
|
|
7649
|
-
|
|
7781
|
+
const bool circular = (bool) ggml_get_op_params_i32(dst, 8);
|
|
7650
7782
|
switch (src0->type) {
|
|
7651
7783
|
case GGML_TYPE_F32:
|
|
7652
7784
|
{
|
|
7653
|
-
|
|
7785
|
+
if (circular) {
|
|
7786
|
+
ggml_compute_forward_pad_f32<true>(params, dst);
|
|
7787
|
+
} else {
|
|
7788
|
+
ggml_compute_forward_pad_f32<false>(params, dst);
|
|
7789
|
+
}
|
|
7654
7790
|
} break;
|
|
7655
7791
|
default:
|
|
7656
7792
|
{
|
|
@@ -7854,6 +7990,18 @@ void ggml_compute_forward_timestep_embedding(
|
|
|
7854
7990
|
|
|
7855
7991
|
// ggml_compute_forward_argsort
|
|
7856
7992
|
|
|
7993
|
+
template<enum ggml_sort_order order>
|
|
7994
|
+
struct cmp_argsort {
|
|
7995
|
+
const float * data;
|
|
7996
|
+
bool operator()(int32_t a, int32_t b) const {
|
|
7997
|
+
if constexpr (order == GGML_SORT_ORDER_ASC) {
|
|
7998
|
+
return data[a] < data[b];
|
|
7999
|
+
} else {
|
|
8000
|
+
return data[a] > data[b];
|
|
8001
|
+
}
|
|
8002
|
+
}
|
|
8003
|
+
};
|
|
8004
|
+
|
|
7857
8005
|
static void ggml_compute_forward_argsort_f32(
|
|
7858
8006
|
const ggml_compute_params * params,
|
|
7859
8007
|
ggml_tensor * dst) {
|
|
@@ -7872,23 +8020,25 @@ static void ggml_compute_forward_argsort_f32(
|
|
|
7872
8020
|
ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
|
|
7873
8021
|
|
|
7874
8022
|
for (int64_t i = ith; i < nr; i += nth) {
|
|
7875
|
-
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
|
|
7876
8023
|
const float * src_data = (float *)((char *) src0->data + i*nb01);
|
|
7877
8024
|
|
|
8025
|
+
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
|
|
8026
|
+
|
|
7878
8027
|
for (int64_t j = 0; j < ne0; j++) {
|
|
7879
8028
|
dst_data[j] = j;
|
|
7880
8029
|
}
|
|
7881
8030
|
|
|
7882
|
-
|
|
7883
|
-
|
|
7884
|
-
|
|
7885
|
-
|
|
7886
|
-
|
|
7887
|
-
|
|
7888
|
-
|
|
7889
|
-
|
|
7890
|
-
|
|
7891
|
-
|
|
8031
|
+
switch (order) {
|
|
8032
|
+
case GGML_SORT_ORDER_ASC:
|
|
8033
|
+
std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
|
|
8034
|
+
break;
|
|
8035
|
+
|
|
8036
|
+
case GGML_SORT_ORDER_DESC:
|
|
8037
|
+
std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
|
|
8038
|
+
break;
|
|
8039
|
+
|
|
8040
|
+
default:
|
|
8041
|
+
GGML_ABORT("invalid sort order");
|
|
7892
8042
|
}
|
|
7893
8043
|
}
|
|
7894
8044
|
}
|
|
@@ -7911,12 +8061,80 @@ void ggml_compute_forward_argsort(
|
|
|
7911
8061
|
}
|
|
7912
8062
|
}
|
|
7913
8063
|
|
|
7914
|
-
//
|
|
8064
|
+
// ggml_compute_forward_top_k
|
|
7915
8065
|
|
|
7916
|
-
|
|
8066
|
+
struct cmp_top_k {
|
|
8067
|
+
const float * data;
|
|
8068
|
+
bool operator()(int32_t a, int32_t b) const {
|
|
8069
|
+
return data[a] > data[b];
|
|
8070
|
+
}
|
|
8071
|
+
};
|
|
8072
|
+
|
|
8073
|
+
static void ggml_compute_forward_top_k_f32(
|
|
8074
|
+
const ggml_compute_params * params,
|
|
8075
|
+
ggml_tensor * dst) {
|
|
8076
|
+
|
|
8077
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
8078
|
+
|
|
8079
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
|
8080
|
+
|
|
8081
|
+
GGML_ASSERT(nb0 == sizeof(float));
|
|
8082
|
+
|
|
8083
|
+
const int ith = params->ith;
|
|
8084
|
+
const int nth = params->nth;
|
|
8085
|
+
|
|
8086
|
+
const int64_t nr = ggml_nrows(src0);
|
|
8087
|
+
|
|
8088
|
+
const int top_k = ne0;
|
|
8089
|
+
|
|
8090
|
+
int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
|
8091
|
+
|
|
8092
|
+
for (int64_t i = ith; i < nr; i += nth) {
|
|
8093
|
+
const float * src_data = (float *)((char *) src0->data + i*nb01);
|
|
8094
|
+
|
|
8095
|
+
for (int64_t j = 0; j < ne00; j++) {
|
|
8096
|
+
tmp[j] = j;
|
|
8097
|
+
}
|
|
8098
|
+
|
|
8099
|
+
std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
|
|
8100
|
+
|
|
8101
|
+
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
|
|
8102
|
+
|
|
8103
|
+
std::copy(tmp, tmp + top_k, dst_data);
|
|
8104
|
+
|
|
8105
|
+
// emphasize that the order is not important
|
|
8106
|
+
if (top_k > 1) {
|
|
8107
|
+
std::swap(dst_data[0], dst_data[1]);
|
|
8108
|
+
}
|
|
8109
|
+
}
|
|
8110
|
+
}
|
|
8111
|
+
|
|
8112
|
+
void ggml_compute_forward_top_k(
|
|
8113
|
+
const ggml_compute_params * params,
|
|
8114
|
+
ggml_tensor * dst) {
|
|
8115
|
+
|
|
8116
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
8117
|
+
|
|
8118
|
+
switch (src0->type) {
|
|
8119
|
+
case GGML_TYPE_F32:
|
|
8120
|
+
{
|
|
8121
|
+
ggml_compute_forward_top_k_f32(params, dst);
|
|
8122
|
+
} break;
|
|
8123
|
+
default:
|
|
8124
|
+
{
|
|
8125
|
+
GGML_ABORT("fatal error");
|
|
8126
|
+
}
|
|
8127
|
+
}
|
|
8128
|
+
}
|
|
8129
|
+
|
|
8130
|
+
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|
7917
8131
|
const ggml_compute_params * params,
|
|
7918
|
-
ggml_tensor * dst
|
|
8132
|
+
ggml_tensor * dst,
|
|
8133
|
+
int ir0, int ir1,
|
|
8134
|
+
int64_t ic_start, int64_t ic_end,
|
|
8135
|
+
float * partials, int64_t partial_stride) {
|
|
7919
8136
|
|
|
8137
|
+
const bool write_partials = (partials != nullptr);
|
|
7920
8138
|
const ggml_tensor * q = dst->src[0];
|
|
7921
8139
|
const ggml_tensor * k = dst->src[1];
|
|
7922
8140
|
const ggml_tensor * v = dst->src[2];
|
|
@@ -7932,9 +8150,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
7932
8150
|
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
7933
8151
|
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
7934
8152
|
|
|
7935
|
-
const int ith = params->ith;
|
|
7936
|
-
const int nth = params->nth;
|
|
7937
|
-
|
|
7938
8153
|
const int64_t DK = nek0;
|
|
7939
8154
|
const int64_t DV = nev0;
|
|
7940
8155
|
const int64_t N = neq1;
|
|
@@ -7968,16 +8183,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
7968
8183
|
|
|
7969
8184
|
// parallelize by q rows using ggml_vec_dot_f32
|
|
7970
8185
|
|
|
7971
|
-
// total rows in q
|
|
7972
|
-
const int nr = neq1*neq2*neq3;
|
|
7973
|
-
|
|
7974
|
-
// rows per thread
|
|
7975
|
-
const int dr = (nr + nth - 1)/nth;
|
|
7976
|
-
|
|
7977
|
-
// row range for this thread
|
|
7978
|
-
const int ir0 = dr*ith;
|
|
7979
|
-
const int ir1 = MIN(ir0 + dr, nr);
|
|
7980
|
-
|
|
7981
8186
|
float scale = 1.0f;
|
|
7982
8187
|
float max_bias = 0.0f;
|
|
7983
8188
|
float logit_softcap = 0.0f;
|
|
@@ -8004,7 +8209,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
8004
8209
|
GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
|
|
8005
8210
|
GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
|
|
8006
8211
|
|
|
8007
|
-
|
|
8212
|
+
int ith = params->ith;
|
|
8213
|
+
|
|
8008
8214
|
for (int ir = ir0; ir < ir1; ++ir) {
|
|
8009
8215
|
// q indices
|
|
8010
8216
|
const int iq3 = ir/(neq2*neq1);
|
|
@@ -8044,7 +8250,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
8044
8250
|
// online softmax / attention
|
|
8045
8251
|
// loop over n_kv and n_head_kv
|
|
8046
8252
|
// ref: https://arxiv.org/pdf/2112.05682.pdf
|
|
8047
|
-
|
|
8253
|
+
|
|
8254
|
+
for (int64_t ic = ic_start; ic < ic_end; ++ic) {
|
|
8048
8255
|
const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
|
|
8049
8256
|
if (mv == -INFINITY) {
|
|
8050
8257
|
continue;
|
|
@@ -8117,8 +8324,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
8117
8324
|
}
|
|
8118
8325
|
}
|
|
8119
8326
|
|
|
8120
|
-
// sinks
|
|
8121
|
-
if (sinks) {
|
|
8327
|
+
// sinks - apply only on the first kv-chunk
|
|
8328
|
+
if (sinks && ic_start == 0) {
|
|
8122
8329
|
const float s = ((float *)((char *) sinks->data))[h];
|
|
8123
8330
|
|
|
8124
8331
|
float ms = 1.0f;
|
|
@@ -8126,6 +8333,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
8126
8333
|
|
|
8127
8334
|
if (s > M) {
|
|
8128
8335
|
ms = expf(M - s);
|
|
8336
|
+
M = s;
|
|
8129
8337
|
ggml_vec_scale_f32(DV, VKQ32, ms);
|
|
8130
8338
|
} else {
|
|
8131
8339
|
vs = expf(s - M);
|
|
@@ -8134,20 +8342,517 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
8134
8342
|
S = S*ms + vs;
|
|
8135
8343
|
}
|
|
8136
8344
|
|
|
8137
|
-
|
|
8138
|
-
|
|
8139
|
-
|
|
8345
|
+
if (write_partials) {
|
|
8346
|
+
// Write M, S, VKQ to partials for later reduction
|
|
8347
|
+
// partials layout: [M, S, VKQ[DV]] per query head
|
|
8348
|
+
float * partial = partials + ir * partial_stride;
|
|
8349
|
+
partial[0] = M;
|
|
8350
|
+
partial[1] = S;
|
|
8351
|
+
memcpy(partial + 2, VKQ32, DV * sizeof(float));
|
|
8352
|
+
} else {
|
|
8353
|
+
// V /= S
|
|
8354
|
+
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
|
|
8355
|
+
ggml_vec_scale_f32(DV, VKQ32, S_inv);
|
|
8140
8356
|
|
|
8141
|
-
|
|
8142
|
-
|
|
8143
|
-
|
|
8144
|
-
|
|
8357
|
+
// dst indices
|
|
8358
|
+
const int i1 = iq1;
|
|
8359
|
+
const int i2 = iq2;
|
|
8360
|
+
const int i3 = iq3;
|
|
8361
|
+
|
|
8362
|
+
// permute(0, 2, 1, 3)
|
|
8363
|
+
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
|
|
8364
|
+
}
|
|
8365
|
+
}
|
|
8366
|
+
}
|
|
8367
|
+
|
|
8368
|
+
static void ggml_compute_forward_flash_attn_ext_tiled(
|
|
8369
|
+
const ggml_compute_params * params,
|
|
8370
|
+
ggml_tensor * dst,
|
|
8371
|
+
int ir0, int ir1) {
|
|
8372
|
+
const ggml_tensor * q = dst->src[0];
|
|
8373
|
+
const ggml_tensor * k = dst->src[1];
|
|
8374
|
+
const ggml_tensor * v = dst->src[2];
|
|
8375
|
+
const ggml_tensor * mask = dst->src[3];
|
|
8376
|
+
const ggml_tensor * sinks = dst->src[4];
|
|
8377
|
+
|
|
8378
|
+
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
|
8379
|
+
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
|
8380
|
+
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
|
8381
|
+
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
|
8382
|
+
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
|
8383
|
+
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
|
8384
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
8385
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
8386
|
+
|
|
8387
|
+
const int64_t DK = nek0;
|
|
8388
|
+
const int64_t DV = nev0;
|
|
8389
|
+
const int64_t N = neq1;
|
|
8390
|
+
|
|
8391
|
+
GGML_ASSERT(ne0 == DV);
|
|
8392
|
+
GGML_ASSERT(ne2 == N);
|
|
8393
|
+
|
|
8394
|
+
// input tensor rows must be contiguous
|
|
8395
|
+
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
|
|
8396
|
+
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
|
|
8397
|
+
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
|
|
8398
|
+
|
|
8399
|
+
GGML_ASSERT(neq0 == DK);
|
|
8400
|
+
GGML_ASSERT(nek0 == DK);
|
|
8401
|
+
GGML_ASSERT(nev0 == DV);
|
|
8402
|
+
|
|
8403
|
+
GGML_ASSERT(neq1 == N);
|
|
8404
|
+
|
|
8405
|
+
// dst cannot be transposed or permuted
|
|
8406
|
+
GGML_ASSERT(nb0 == sizeof(float));
|
|
8407
|
+
GGML_ASSERT(nb0 <= nb1);
|
|
8408
|
+
GGML_ASSERT(nb1 <= nb2);
|
|
8409
|
+
GGML_ASSERT(nb2 <= nb3);
|
|
8410
|
+
|
|
8411
|
+
GGML_ASSERT(k->type == v->type);
|
|
8412
|
+
const ggml_type kv_type = k->type;
|
|
8145
8413
|
|
|
8146
|
-
// original
|
|
8147
|
-
//memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
|
|
8148
8414
|
|
|
8149
|
-
|
|
8150
|
-
|
|
8415
|
+
// broadcast factors
|
|
8416
|
+
const int64_t rk2 = neq2/nek2;
|
|
8417
|
+
const int64_t rk3 = neq3/nek3;
|
|
8418
|
+
|
|
8419
|
+
const int64_t rv2 = neq2/nev2;
|
|
8420
|
+
const int64_t rv3 = neq3/nev3;
|
|
8421
|
+
|
|
8422
|
+
float scale = 1.0f;
|
|
8423
|
+
float max_bias = 0.0f;
|
|
8424
|
+
float logit_softcap = 0.0f;
|
|
8425
|
+
|
|
8426
|
+
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
|
8427
|
+
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
|
8428
|
+
memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
|
|
8429
|
+
|
|
8430
|
+
if (logit_softcap != 0) {
|
|
8431
|
+
scale /= logit_softcap;
|
|
8432
|
+
}
|
|
8433
|
+
|
|
8434
|
+
const uint32_t n_head = neq2;
|
|
8435
|
+
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
|
8436
|
+
|
|
8437
|
+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
8438
|
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
8439
|
+
|
|
8440
|
+
int ith = params->ith;
|
|
8441
|
+
|
|
8442
|
+
static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
|
|
8443
|
+
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
|
|
8444
|
+
|
|
8445
|
+
int ir = ir0;
|
|
8446
|
+
while (ir < ir1) {
|
|
8447
|
+
// q indices for the start of this tile
|
|
8448
|
+
const int iq3 = ir/(neq2*neq1);
|
|
8449
|
+
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
|
|
8450
|
+
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
|
|
8451
|
+
|
|
8452
|
+
// Number of valid rows in this tile:
|
|
8453
|
+
// - limited by tile size (Q_TILE_SZ)
|
|
8454
|
+
// - limited by chunk boundary (ir1 - ir)
|
|
8455
|
+
// - limited by head boundary (neq1 - iq1) to avoid crossing into next head
|
|
8456
|
+
const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1)));
|
|
8457
|
+
GGML_ASSERT(tile_rows > 0);
|
|
8458
|
+
|
|
8459
|
+
const uint32_t h = iq2; // head index
|
|
8460
|
+
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
|
8461
|
+
|
|
8462
|
+
float S[Q_TILE_SZ];
|
|
8463
|
+
float M[Q_TILE_SZ];
|
|
8464
|
+
|
|
8465
|
+
for (int i = 0 ; i < Q_TILE_SZ; ++i) {
|
|
8466
|
+
S[i] = 0.;
|
|
8467
|
+
M[i] = -INFINITY;
|
|
8468
|
+
}
|
|
8469
|
+
|
|
8470
|
+
// Per-thread scratch layout:
|
|
8471
|
+
// Q_q: Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar)
|
|
8472
|
+
// KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
|
|
8473
|
+
// mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
|
|
8474
|
+
// VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
|
|
8475
|
+
// V32: KV_TILE_SZ * DV (F32 buffer for V tile)
|
|
8476
|
+
// K_f32: KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path)
|
|
8477
|
+
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32);
|
|
8478
|
+
|
|
8479
|
+
void * Q_q = base;
|
|
8480
|
+
float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
|
|
8481
|
+
float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
|
|
8482
|
+
float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
|
|
8483
|
+
float * V32 = VKQ32 + Q_TILE_SZ * DV;
|
|
8484
|
+
float * K_f32 = V32 + KV_TILE_SZ * DV;
|
|
8485
|
+
|
|
8486
|
+
memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
|
|
8487
|
+
memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
|
|
8488
|
+
|
|
8489
|
+
// k indices
|
|
8490
|
+
const int ik3 = iq3 / rk3;
|
|
8491
|
+
const int ik2 = iq2 / rk2;
|
|
8492
|
+
|
|
8493
|
+
// v indices
|
|
8494
|
+
const int iv3 = iq3 / rv3;
|
|
8495
|
+
const int iv2 = iq2 / rv2;
|
|
8496
|
+
|
|
8497
|
+
{
|
|
8498
|
+
float * Q_f32 = (float *)Q_q;
|
|
8499
|
+
for (int tq = 0; tq < tile_rows; tq++) {
|
|
8500
|
+
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
|
|
8501
|
+
memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
|
|
8502
|
+
}
|
|
8503
|
+
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
|
|
8504
|
+
memset(Q_f32 + tq * DK, 0, DK * sizeof(float));
|
|
8505
|
+
}
|
|
8506
|
+
}
|
|
8507
|
+
|
|
8508
|
+
memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
|
|
8509
|
+
memset(V32, 0, KV_TILE_SZ * DV * sizeof(float));
|
|
8510
|
+
|
|
8511
|
+
for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
|
|
8512
|
+
const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);
|
|
8513
|
+
|
|
8514
|
+
// skip the tile entirely if all the masks are -inf
|
|
8515
|
+
if (mask) {
|
|
8516
|
+
bool can_skip = true;
|
|
8517
|
+
for (int tq = 0; tq < tile_rows; tq++) {
|
|
8518
|
+
const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
|
|
8519
|
+
for (int tk = 0; tk < kv_tile; tk++) {
|
|
8520
|
+
mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
|
|
8521
|
+
if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
|
|
8522
|
+
can_skip = false;
|
|
8523
|
+
}
|
|
8524
|
+
}
|
|
8525
|
+
// Pad remaining mask entries with -inf
|
|
8526
|
+
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
|
|
8527
|
+
mask32[tq * KV_TILE_SZ + tk] = -INFINITY;
|
|
8528
|
+
}
|
|
8529
|
+
}
|
|
8530
|
+
|
|
8531
|
+
if (can_skip) {
|
|
8532
|
+
continue;
|
|
8533
|
+
}
|
|
8534
|
+
}
|
|
8535
|
+
|
|
8536
|
+
// Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)
|
|
8537
|
+
// Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
|
|
8538
|
+
for (int tk = 0; tk < kv_tile; tk++) {
|
|
8539
|
+
const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
|
|
8540
|
+
if (kv_type == GGML_TYPE_F16) {
|
|
8541
|
+
const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
|
|
8542
|
+
for (int64_t dk = 0; dk < DK; dk++) {
|
|
8543
|
+
K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);
|
|
8544
|
+
}
|
|
8545
|
+
} else {
|
|
8546
|
+
const float * k_f32_src = (const float *)k_data;
|
|
8547
|
+
for (int64_t dk = 0; dk < DK; dk++) {
|
|
8548
|
+
K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
|
|
8549
|
+
}
|
|
8550
|
+
}
|
|
8551
|
+
}
|
|
8552
|
+
memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
|
|
8553
|
+
simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ);
|
|
8554
|
+
ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale);
|
|
8555
|
+
|
|
8556
|
+
// Set padded KQ entries to -inf so softmax gives them zero weight
|
|
8557
|
+
if (kv_tile < KV_TILE_SZ) {
|
|
8558
|
+
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
|
8559
|
+
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
|
|
8560
|
+
KQ[tq * KV_TILE_SZ + tk] = -INFINITY;
|
|
8561
|
+
}
|
|
8562
|
+
}
|
|
8563
|
+
}
|
|
8564
|
+
|
|
8565
|
+
if (logit_softcap != 0.0f) {
|
|
8566
|
+
ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ);
|
|
8567
|
+
ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap);
|
|
8568
|
+
}
|
|
8569
|
+
|
|
8570
|
+
if (mask) {
|
|
8571
|
+
ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32);
|
|
8572
|
+
}
|
|
8573
|
+
|
|
8574
|
+
bool skip[Q_TILE_SZ] = {};
|
|
8575
|
+
|
|
8576
|
+
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
|
8577
|
+
float * kq_row = KQ + tq * KV_TILE_SZ;
|
|
8578
|
+
|
|
8579
|
+
float tile_max;
|
|
8580
|
+
ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row);
|
|
8581
|
+
|
|
8582
|
+
if (tile_max == -INFINITY) {
|
|
8583
|
+
skip[tq] = true;
|
|
8584
|
+
continue;
|
|
8585
|
+
}
|
|
8586
|
+
|
|
8587
|
+
const float Mold = M[tq];
|
|
8588
|
+
const float Mnew = fmaxf(Mold, tile_max);
|
|
8589
|
+
|
|
8590
|
+
if (Mnew > Mold) {
|
|
8591
|
+
const float ms = expf(Mold - Mnew);
|
|
8592
|
+
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
|
|
8593
|
+
S[tq] *= ms;
|
|
8594
|
+
}
|
|
8595
|
+
M[tq] = Mnew;
|
|
8596
|
+
|
|
8597
|
+
|
|
8598
|
+
S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
|
|
8599
|
+
}
|
|
8600
|
+
|
|
8601
|
+
// V accumulation: VKQ32 += softmax(KQ) * V
|
|
8602
|
+
// Pack V tile to contiguous F32, zero-padded
|
|
8603
|
+
for (int tk = 0; tk < kv_tile; tk++) {
|
|
8604
|
+
const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
|
|
8605
|
+
if (kv_type == GGML_TYPE_F16) {
|
|
8606
|
+
ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
|
|
8607
|
+
} else {
|
|
8608
|
+
memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
|
|
8609
|
+
}
|
|
8610
|
+
}
|
|
8611
|
+
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
|
8612
|
+
if (skip[tq]) {
|
|
8613
|
+
memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float));
|
|
8614
|
+
}
|
|
8615
|
+
}
|
|
8616
|
+
simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV);
|
|
8617
|
+
}
|
|
8618
|
+
|
|
8619
|
+
// sinks (apply only to valid rows in the tile)
|
|
8620
|
+
if (sinks) {
|
|
8621
|
+
const float s = ((float *)((char *) sinks->data))[h];
|
|
8622
|
+
|
|
8623
|
+
for (int tq = 0; tq < tile_rows; tq++) {
|
|
8624
|
+
float ms = 1.0f;
|
|
8625
|
+
float vs = 1.0f;
|
|
8626
|
+
|
|
8627
|
+
if (s > M[tq]) {
|
|
8628
|
+
ms = expf(M[tq] - s);
|
|
8629
|
+
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
|
|
8630
|
+
} else {
|
|
8631
|
+
vs = expf(s - M[tq]);
|
|
8632
|
+
}
|
|
8633
|
+
|
|
8634
|
+
S[tq] = S[tq] * ms + vs;
|
|
8635
|
+
}
|
|
8636
|
+
}
|
|
8637
|
+
|
|
8638
|
+
for (int tq = 0; tq < tile_rows; tq++) {
|
|
8639
|
+
// V /= S
|
|
8640
|
+
const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq];
|
|
8641
|
+
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv);
|
|
8642
|
+
|
|
8643
|
+
// dst indices
|
|
8644
|
+
const int i1 = iq1 + tq;
|
|
8645
|
+
const int i2 = iq2;
|
|
8646
|
+
const int i3 = iq3;
|
|
8647
|
+
|
|
8648
|
+
// permute(0, 2, 1, 3)
|
|
8649
|
+
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1);
|
|
8650
|
+
}
|
|
8651
|
+
|
|
8652
|
+
ir += tile_rows;
|
|
8653
|
+
}
|
|
8654
|
+
}
|
|
8655
|
+
|
|
8656
|
+
// Reduction function: combines partial results across KV chunks
|
|
8657
|
+
// Partials layout in wdata: [n_q_heads][n_chunks][2 + DV]
|
|
8658
|
+
static void ggml_flash_attn_ext_reduce_partials(
|
|
8659
|
+
const ggml_compute_params * params,
|
|
8660
|
+
ggml_tensor * dst,
|
|
8661
|
+
const int64_t n_chunks,
|
|
8662
|
+
const int64_t chunk_size) {
|
|
8663
|
+
|
|
8664
|
+
const ggml_tensor * q = dst->src[0];
|
|
8665
|
+
const ggml_tensor * k = dst->src[1];
|
|
8666
|
+
const ggml_tensor * v = dst->src[2];
|
|
8667
|
+
|
|
8668
|
+
const int64_t DK = k->ne[0];
|
|
8669
|
+
const int64_t DV = v->ne[0];
|
|
8670
|
+
const int64_t nek1 = k->ne[1];
|
|
8671
|
+
const int64_t n_q_heads = q->ne[2];
|
|
8672
|
+
|
|
8673
|
+
const int ith = params->ith;
|
|
8674
|
+
const int nth = params->nth;
|
|
8675
|
+
|
|
8676
|
+
const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32;
|
|
8677
|
+
float * thread_wdata = (float *) params->wdata + ith * wdata_per_thread;
|
|
8678
|
+
|
|
8679
|
+
const int64_t partials_offset = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
|
|
8680
|
+
const int64_t partial_size = 2 + DV;
|
|
8681
|
+
const float * partials_base = (const float *) params->wdata + partials_offset;
|
|
8682
|
+
|
|
8683
|
+
// Output layout
|
|
8684
|
+
const int64_t ne1 = dst->ne[1];
|
|
8685
|
+
const int64_t ne2 = dst->ne[2];
|
|
8686
|
+
const size_t nb1 = dst->nb[1];
|
|
8687
|
+
|
|
8688
|
+
// Each thread reduces a subset of query heads
|
|
8689
|
+
for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) {
|
|
8690
|
+
float M_final = -INFINITY;
|
|
8691
|
+
float S_final = 0.0f;
|
|
8692
|
+
float * VKQ_final = thread_wdata;
|
|
8693
|
+
memset(VKQ_final, 0, DV * sizeof(float));
|
|
8694
|
+
|
|
8695
|
+
// Combine partials from all chunks
|
|
8696
|
+
for (int64_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
|
|
8697
|
+
const int64_t ic_start = chunk_idx * chunk_size;
|
|
8698
|
+
if (ic_start >= nek1) continue;
|
|
8699
|
+
|
|
8700
|
+
const float * partial = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
|
|
8701
|
+
const float M_chunk = partial[0];
|
|
8702
|
+
const float S_chunk = partial[1];
|
|
8703
|
+
const float * VKQ_chunk = partial + 2;
|
|
8704
|
+
|
|
8705
|
+
if (S_chunk == 0.0f) continue;
|
|
8706
|
+
|
|
8707
|
+
const float M_new = fmaxf(M_final, M_chunk);
|
|
8708
|
+
const float scale_old = expf(M_final - M_new);
|
|
8709
|
+
const float scale_new = expf(M_chunk - M_new);
|
|
8710
|
+
|
|
8711
|
+
for (int64_t d = 0; d < DV; ++d) {
|
|
8712
|
+
VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new;
|
|
8713
|
+
}
|
|
8714
|
+
S_final = S_final * scale_old + S_chunk * scale_new;
|
|
8715
|
+
M_final = M_new;
|
|
8716
|
+
}
|
|
8717
|
+
|
|
8718
|
+
// Normalize and write to output
|
|
8719
|
+
if (S_final != 0.0f) {
|
|
8720
|
+
const float S_inv = 1.0f / S_final;
|
|
8721
|
+
ggml_vec_scale_f32(DV, VKQ_final, S_inv);
|
|
8722
|
+
}
|
|
8723
|
+
// iq1=0, iq3=0 for decode
|
|
8724
|
+
memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1);
|
|
8725
|
+
}
|
|
8726
|
+
}
|
|
8727
|
+
|
|
8728
|
+
static void ggml_compute_forward_flash_attn_ext_f16(
|
|
8729
|
+
const ggml_compute_params * params,
|
|
8730
|
+
ggml_tensor * dst) {
|
|
8731
|
+
|
|
8732
|
+
const ggml_tensor * q = dst->src[0];
|
|
8733
|
+
const ggml_tensor * k = dst->src[1];
|
|
8734
|
+
const ggml_tensor * v = dst->src[2];
|
|
8735
|
+
|
|
8736
|
+
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
|
8737
|
+
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
|
8738
|
+
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
|
8739
|
+
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
|
8740
|
+
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
|
8741
|
+
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
|
8742
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
8743
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
8744
|
+
|
|
8745
|
+
const int64_t DK = nek0;
|
|
8746
|
+
const int64_t DV = nev0;
|
|
8747
|
+
const int64_t N = neq1;
|
|
8748
|
+
|
|
8749
|
+
|
|
8750
|
+
GGML_ASSERT(ne0 == DV);
|
|
8751
|
+
GGML_ASSERT(ne2 == N);
|
|
8752
|
+
|
|
8753
|
+
// input tensor rows must be contiguous
|
|
8754
|
+
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
|
|
8755
|
+
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
|
|
8756
|
+
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
|
|
8757
|
+
|
|
8758
|
+
GGML_ASSERT(neq0 == DK);
|
|
8759
|
+
GGML_ASSERT(nek0 == DK);
|
|
8760
|
+
GGML_ASSERT(nev0 == DV);
|
|
8761
|
+
|
|
8762
|
+
GGML_ASSERT(neq1 == N);
|
|
8763
|
+
|
|
8764
|
+
// dst cannot be transposed or permuted
|
|
8765
|
+
GGML_ASSERT(nb0 == sizeof(float));
|
|
8766
|
+
GGML_ASSERT(nb0 <= nb1);
|
|
8767
|
+
GGML_ASSERT(nb1 <= nb2);
|
|
8768
|
+
GGML_ASSERT(nb2 <= nb3);
|
|
8769
|
+
|
|
8770
|
+
const int ith = params->ith;
|
|
8771
|
+
const int nth = params->nth;
|
|
8772
|
+
|
|
8773
|
+
// When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking)
|
|
8774
|
+
const bool use_ref = params->use_ref;
|
|
8775
|
+
|
|
8776
|
+
const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
|
|
8777
|
+
const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512;
|
|
8778
|
+
|
|
8779
|
+
if (use_split_kv_path) {
|
|
8780
|
+
const int64_t chunk_size = (nek1 + nth - 1) / nth;
|
|
8781
|
+
|
|
8782
|
+
// Partials buffer layout: [q_head][kv_chunk][M, S, VKQ]
|
|
8783
|
+
const int64_t partial_size = 2 + DV;
|
|
8784
|
+
float * partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
|
|
8785
|
+
|
|
8786
|
+
const int64_t ic_start = ith * chunk_size;
|
|
8787
|
+
const int64_t ic_end = std::min(ic_start + chunk_size, nek1);
|
|
8788
|
+
|
|
8789
|
+
const int64_t partial_stride = nth * partial_size;
|
|
8790
|
+
float * chunk_partials = partials_base + ith * partial_size;
|
|
8791
|
+
|
|
8792
|
+
if (ic_start < nek1) {
|
|
8793
|
+
for (int64_t q_head = 0; q_head < neq2; q_head++) {
|
|
8794
|
+
ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|
8795
|
+
params, dst, q_head, q_head + 1, ic_start, ic_end,
|
|
8796
|
+
chunk_partials, partial_stride);
|
|
8797
|
+
}
|
|
8798
|
+
} else {
|
|
8799
|
+
for (int64_t q_head = 0; q_head < neq2; q_head++) {
|
|
8800
|
+
float * q_partials = chunk_partials + q_head * partial_stride;
|
|
8801
|
+
q_partials[0] = -INFINITY; // M
|
|
8802
|
+
q_partials[1] = 0.0f; // S
|
|
8803
|
+
}
|
|
8804
|
+
}
|
|
8805
|
+
|
|
8806
|
+
ggml_barrier(params->threadpool);
|
|
8807
|
+
ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size);
|
|
8808
|
+
} else {
|
|
8809
|
+
|
|
8810
|
+
// total rows in q
|
|
8811
|
+
const int64_t nr = neq1*neq2*neq3;
|
|
8812
|
+
|
|
8813
|
+
// disable for NUMA
|
|
8814
|
+
const bool disable_chunking = ggml_is_numa();
|
|
8815
|
+
|
|
8816
|
+
// 4x chunks per thread
|
|
8817
|
+
int nth_scaled = nth * 4;
|
|
8818
|
+
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
|
|
8819
|
+
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
|
|
8820
|
+
|
|
8821
|
+
if (nth == 1 || nchunk < nth || disable_chunking) {
|
|
8822
|
+
nchunk = nth;
|
|
8823
|
+
}
|
|
8824
|
+
|
|
8825
|
+
if (ith == 0) {
|
|
8826
|
+
ggml_threadpool_chunk_set(params->threadpool, nth);
|
|
8827
|
+
}
|
|
8828
|
+
|
|
8829
|
+
ggml_barrier(params->threadpool);
|
|
8830
|
+
|
|
8831
|
+
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
|
8832
|
+
|
|
8833
|
+
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
|
|
8834
|
+
bool use_tiled = !use_ref &&
|
|
8835
|
+
(q->type == GGML_TYPE_F32 &&
|
|
8836
|
+
kv_is_f32_or_f16 &&
|
|
8837
|
+
k->type == v->type &&
|
|
8838
|
+
neq1 >= Q_TILE_SZ);
|
|
8839
|
+
#ifdef GGML_SIMD
|
|
8840
|
+
use_tiled &= (DV % GGML_F32_EPR == 0);
|
|
8841
|
+
#endif
|
|
8842
|
+
int current_chunk = ith;
|
|
8843
|
+
|
|
8844
|
+
while (current_chunk < nchunk) {
|
|
8845
|
+
const int64_t ir0 = dr * current_chunk;
|
|
8846
|
+
const int64_t ir1 = MIN(ir0 + dr, nr);
|
|
8847
|
+
|
|
8848
|
+
if (use_tiled) {
|
|
8849
|
+
ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
|
|
8850
|
+
} else {
|
|
8851
|
+
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0);
|
|
8852
|
+
}
|
|
8853
|
+
|
|
8854
|
+
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
|
8855
|
+
}
|
|
8151
8856
|
}
|
|
8152
8857
|
}
|
|
8153
8858
|
|
|
@@ -8637,7 +9342,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|
|
8637
9342
|
// n_head
|
|
8638
9343
|
for (int h = ih0; h < ih1; ++h) {
|
|
8639
9344
|
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
|
8640
|
-
const float dt_soft_plus =
|
|
9345
|
+
const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
|
|
8641
9346
|
const float dA = expf(dt_soft_plus * A[h]);
|
|
8642
9347
|
const int g = h / (nh / ng); // repeat_interleave
|
|
8643
9348
|
|
|
@@ -8734,7 +9439,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|
|
8734
9439
|
// n_head
|
|
8735
9440
|
for (int h = ih0; h < ih1; ++h) {
|
|
8736
9441
|
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
|
8737
|
-
const float dt_soft_plus =
|
|
9442
|
+
const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
|
|
8738
9443
|
const int g = h / (nh / ng); // repeat_interleave
|
|
8739
9444
|
|
|
8740
9445
|
// dim
|
|
@@ -8928,7 +9633,7 @@ void ggml_compute_forward_win_unpart(
|
|
|
8928
9633
|
}
|
|
8929
9634
|
}
|
|
8930
9635
|
|
|
8931
|
-
//
|
|
9636
|
+
//ggml_compute_forward_unary
|
|
8932
9637
|
|
|
8933
9638
|
void ggml_compute_forward_unary(
|
|
8934
9639
|
const ggml_compute_params * params,
|
|
@@ -8997,6 +9702,34 @@ void ggml_compute_forward_unary(
|
|
|
8997
9702
|
{
|
|
8998
9703
|
ggml_compute_forward_exp(params, dst);
|
|
8999
9704
|
} break;
|
|
9705
|
+
case GGML_UNARY_OP_FLOOR:
|
|
9706
|
+
{
|
|
9707
|
+
ggml_compute_forward_floor(params, dst);
|
|
9708
|
+
} break;
|
|
9709
|
+
case GGML_UNARY_OP_CEIL:
|
|
9710
|
+
{
|
|
9711
|
+
ggml_compute_forward_ceil(params, dst);
|
|
9712
|
+
} break;
|
|
9713
|
+
case GGML_UNARY_OP_ROUND:
|
|
9714
|
+
{
|
|
9715
|
+
ggml_compute_forward_round(params, dst);
|
|
9716
|
+
} break;
|
|
9717
|
+
case GGML_UNARY_OP_TRUNC:
|
|
9718
|
+
{
|
|
9719
|
+
ggml_compute_forward_trunc(params, dst);
|
|
9720
|
+
} break;
|
|
9721
|
+
case GGML_UNARY_OP_XIELU:
|
|
9722
|
+
{
|
|
9723
|
+
ggml_compute_forward_xielu(params, dst);
|
|
9724
|
+
} break;
|
|
9725
|
+
case GGML_UNARY_OP_EXPM1:
|
|
9726
|
+
{
|
|
9727
|
+
ggml_compute_forward_expm1(params, dst);
|
|
9728
|
+
} break;
|
|
9729
|
+
case GGML_UNARY_OP_SOFTPLUS:
|
|
9730
|
+
{
|
|
9731
|
+
ggml_compute_forward_softplus(params, dst);
|
|
9732
|
+
} break;
|
|
9000
9733
|
default:
|
|
9001
9734
|
{
|
|
9002
9735
|
GGML_ABORT("fatal error");
|
|
@@ -9593,6 +10326,265 @@ void ggml_compute_forward_gla(
|
|
|
9593
10326
|
}
|
|
9594
10327
|
}
|
|
9595
10328
|
|
|
10329
|
+
static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
|
10330
|
+
const struct ggml_tensor * src0 = dst->src[0]; // A (lower triangular)
|
|
10331
|
+
const struct ggml_tensor * src1 = dst->src[1]; // B (RHS)
|
|
10332
|
+
|
|
10333
|
+
GGML_TENSOR_BINARY_OP_LOCALS;
|
|
10334
|
+
|
|
10335
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
10336
|
+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
10337
|
+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
10338
|
+
|
|
10339
|
+
GGML_ASSERT(ne00 == ne01); // A must be square
|
|
10340
|
+
GGML_ASSERT(ne0 == ne10); // solution cols == B cols
|
|
10341
|
+
GGML_ASSERT(ne1 == ne11); // solution rows == B rows
|
|
10342
|
+
|
|
10343
|
+
GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
|
|
10344
|
+
GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
|
|
10345
|
+
|
|
10346
|
+
const int ith = params->ith;
|
|
10347
|
+
const int nth = params->nth;
|
|
10348
|
+
|
|
10349
|
+
const int64_t k = ne10; // number of RHS columns
|
|
10350
|
+
const int64_t n = ne11; // A is n×n
|
|
10351
|
+
const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit
|
|
10352
|
+
|
|
10353
|
+
// chunks per thread
|
|
10354
|
+
const int64_t dr = (nr + nth - 1)/nth;
|
|
10355
|
+
|
|
10356
|
+
// chunk range for this thread
|
|
10357
|
+
const int64_t ir0 = dr*ith;
|
|
10358
|
+
const int64_t ir1 = MIN(ir0 + dr, nr);
|
|
10359
|
+
|
|
10360
|
+
const float * A = (const float *) src0->data; // [n, n, B1, B2]
|
|
10361
|
+
const float * B = (const float *) src1->data; // [n, k, B1, B2]
|
|
10362
|
+
float * X = ( float *) dst->data; // [n, k, B1, B2]
|
|
10363
|
+
|
|
10364
|
+
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
|
10365
|
+
const int64_t i03 = ir/(ne02*k);
|
|
10366
|
+
const int64_t i02 = (ir - i03*ne02*k)/k;
|
|
10367
|
+
const int64_t i01 = (ir - i03*ne02*k - i02*k);
|
|
10368
|
+
|
|
10369
|
+
const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
|
|
10370
|
+
const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
|
|
10371
|
+
|
|
10372
|
+
float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
|
|
10373
|
+
|
|
10374
|
+
for (int64_t i00 = 0; i00 < n; ++i00) {
|
|
10375
|
+
float sum = 0.0f;
|
|
10376
|
+
for (int64_t t = 0; t < i00; ++t) {
|
|
10377
|
+
sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
|
|
10378
|
+
}
|
|
10379
|
+
|
|
10380
|
+
const float diag = A_batch[i00 * n + i00];
|
|
10381
|
+
assert(diag != 0.0f && "Zero diagonal in triangular matrix");
|
|
10382
|
+
|
|
10383
|
+
X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
|
|
10384
|
+
}
|
|
10385
|
+
}
|
|
10386
|
+
}
|
|
10387
|
+
|
|
10388
|
+
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
|
10389
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
10390
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
10391
|
+
|
|
10392
|
+
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
|
10393
|
+
ggml_compute_forward_solve_tri_f32(params, dst);
|
|
10394
|
+
} else {
|
|
10395
|
+
GGML_ABORT("fatal error");
|
|
10396
|
+
}
|
|
10397
|
+
}
|
|
10398
|
+
|
|
10399
|
+
// ggml_compute_forward_gated_delta_net
|
|
10400
|
+
static void ggml_compute_forward_gated_delta_net_one_chunk(
|
|
10401
|
+
const ggml_compute_params * params,
|
|
10402
|
+
ggml_tensor * dst,
|
|
10403
|
+
int64_t ir0,
|
|
10404
|
+
int64_t ir1) {
|
|
10405
|
+
|
|
10406
|
+
ggml_tensor * src_q = dst->src[0];
|
|
10407
|
+
ggml_tensor * src_k = dst->src[1];
|
|
10408
|
+
ggml_tensor * src_v = dst->src[2];
|
|
10409
|
+
ggml_tensor * src_g = dst->src[3];
|
|
10410
|
+
ggml_tensor * src_beta = dst->src[4];
|
|
10411
|
+
ggml_tensor * src_state = dst->src[5];
|
|
10412
|
+
|
|
10413
|
+
const int64_t S_v = src_v->ne[0];
|
|
10414
|
+
const int64_t H = src_v->ne[1];
|
|
10415
|
+
const int64_t n_tokens = src_v->ne[2];
|
|
10416
|
+
const int64_t n_seqs = src_v->ne[3];
|
|
10417
|
+
|
|
10418
|
+
GGML_ASSERT(ggml_is_contiguous_rows(src_q));
|
|
10419
|
+
GGML_ASSERT(ggml_is_contiguous_rows(src_k));
|
|
10420
|
+
GGML_ASSERT(ggml_is_contiguous_rows(src_v));
|
|
10421
|
+
GGML_ASSERT(ggml_is_contiguous(src_g));
|
|
10422
|
+
GGML_ASSERT(ggml_is_contiguous(src_beta));
|
|
10423
|
+
GGML_ASSERT(ggml_is_contiguous(src_state));
|
|
10424
|
+
|
|
10425
|
+
GGML_ASSERT(src_g->ne[0] == 1 || src_g->ne[0] == S_v);
|
|
10426
|
+
GGML_ASSERT(src_beta->ne[0] == 1);
|
|
10427
|
+
|
|
10428
|
+
GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
|
|
10429
|
+
GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
|
|
10430
|
+
GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
|
|
10431
|
+
GGML_TENSOR_LOCALS(size_t, nbk, src_k, nb);
|
|
10432
|
+
GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
|
|
10433
|
+
GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
|
|
10434
|
+
GGML_TENSOR_LOCALS(int64_t, neg, src_g, ne);
|
|
10435
|
+
GGML_TENSOR_LOCALS(size_t, nbg, src_g, nb);
|
|
10436
|
+
GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
|
|
10437
|
+
|
|
10438
|
+
const bool kda = (neg0 == S_v);
|
|
10439
|
+
|
|
10440
|
+
// scratch layout per thread: [delta(S_v)]
|
|
10441
|
+
const int64_t scratch_per_thread = S_v;
|
|
10442
|
+
const int ith = params->ith;
|
|
10443
|
+
|
|
10444
|
+
float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;
|
|
10445
|
+
|
|
10446
|
+
// output layout: [attn_scores | new_states]
|
|
10447
|
+
// attn_scores: S_v * H * n_tokens * n_seqs floats
|
|
10448
|
+
// new_states: S_v * S_v * H * n_seqs floats
|
|
10449
|
+
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
|
|
10450
|
+
float * attn_out_base = (float *)dst->data;
|
|
10451
|
+
float * state_out_base = (float *)dst->data + attn_score_elems;
|
|
10452
|
+
|
|
10453
|
+
const float * state_in_base = (const float *)src_state->data;
|
|
10454
|
+
|
|
10455
|
+
//const int64_t rq1 = nev1 / neq1;
|
|
10456
|
+
//const int64_t rk1 = nev1 / nek1;
|
|
10457
|
+
const int64_t rq3 = nev3 / neq3;
|
|
10458
|
+
const int64_t rk3 = nev3 / nek3;
|
|
10459
|
+
|
|
10460
|
+
const float scale = 1.0f / sqrtf((float) S_v);
|
|
10461
|
+
|
|
10462
|
+
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
|
10463
|
+
const int64_t iv1 = ir % H; // head_index
|
|
10464
|
+
const int64_t iv3 = ir / H; // sequence
|
|
10465
|
+
|
|
10466
|
+
const int64_t iq1 = iv1 % neq1;
|
|
10467
|
+
const int64_t ik1 = iv1 % nek1;
|
|
10468
|
+
|
|
10469
|
+
const int64_t iq3 = iv3 / rq3;
|
|
10470
|
+
const int64_t ik3 = iv3 / rk3;
|
|
10471
|
+
|
|
10472
|
+
float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v;
|
|
10473
|
+
|
|
10474
|
+
// copy input state into output buffer and operate in-place
|
|
10475
|
+
const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v;
|
|
10476
|
+
memcpy(s_out, s_in, S_v * S_v * sizeof(float));
|
|
10477
|
+
|
|
10478
|
+
// attn output pointer for first token of this (head, seq)
|
|
10479
|
+
float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v;
|
|
10480
|
+
|
|
10481
|
+
for (int64_t t = 0; t < n_tokens; t++) {
|
|
10482
|
+
const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1);
|
|
10483
|
+
const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1);
|
|
10484
|
+
const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);
|
|
10485
|
+
|
|
10486
|
+
const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
|
|
10487
|
+
const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
|
|
10488
|
+
|
|
10489
|
+
// state is stored transposed: s_out[j*S_v + i] = S[i][j]
|
|
10490
|
+
// so row j of s_out = column j of S (contiguous access)
|
|
10491
|
+
|
|
10492
|
+
if (kda) {
|
|
10493
|
+
// precompute exp(g) into delta scratch (reused below)
|
|
10494
|
+
for (int64_t i = 0; i < S_v; ++i) {
|
|
10495
|
+
delta[i] = expf(g_d[i]);
|
|
10496
|
+
}
|
|
10497
|
+
// S[i][:] *= exp(g[i]) => for each row j of M: M[j][i] *= exp(g[i])
|
|
10498
|
+
for (int64_t j = 0; j < S_v; ++j) {
|
|
10499
|
+
ggml_vec_mul_f32(S_v, &s_out[j * S_v], &s_out[j * S_v], delta);
|
|
10500
|
+
}
|
|
10501
|
+
} else {
|
|
10502
|
+
ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0]));
|
|
10503
|
+
}
|
|
10504
|
+
|
|
10505
|
+
// delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k)
|
|
10506
|
+
for (int64_t j = 0; j < S_v; ++j) {
|
|
10507
|
+
float sum = 0.0f;
|
|
10508
|
+
ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, k_d, 0, 1);
|
|
10509
|
+
delta[j] = (v_d[j] - sum) * beta_val;
|
|
10510
|
+
}
|
|
10511
|
+
|
|
10512
|
+
// outer product: S[i][j] += k[i] * delta[j] => M[j][i] += delta[j] * k[i]
|
|
10513
|
+
for (int64_t j = 0; j < S_v; ++j) {
|
|
10514
|
+
ggml_vec_mad_f32(S_v, &s_out[j * S_v], k_d, delta[j]);
|
|
10515
|
+
}
|
|
10516
|
+
|
|
10517
|
+
// attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q)
|
|
10518
|
+
for (int64_t j = 0; j < S_v; ++j) {
|
|
10519
|
+
float sum = 0.0f;
|
|
10520
|
+
ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, q_d, 0, 1);
|
|
10521
|
+
attn_data[j] = sum * scale;
|
|
10522
|
+
}
|
|
10523
|
+
|
|
10524
|
+
attn_data += S_v * H; // advance to next token
|
|
10525
|
+
}
|
|
10526
|
+
}
|
|
10527
|
+
}
|
|
10528
|
+
|
|
10529
|
+
|
|
10530
|
+
static void ggml_compute_forward_gated_delta_net_f32(
|
|
10531
|
+
const ggml_compute_params * params,
|
|
10532
|
+
ggml_tensor * dst) {
|
|
10533
|
+
|
|
10534
|
+
ggml_tensor * V = dst->src[2];
|
|
10535
|
+
int64_t nr = V->ne[1] * V->ne[3];
|
|
10536
|
+
|
|
10537
|
+
// disable for NUMA
|
|
10538
|
+
const bool disable_chunking = ggml_is_numa();
|
|
10539
|
+
|
|
10540
|
+
int nth = params->nth;
|
|
10541
|
+
int ith = params->ith;
|
|
10542
|
+
|
|
10543
|
+
// 4x chunks per thread
|
|
10544
|
+
int nth_scaled = nth * 4;
|
|
10545
|
+
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
|
|
10546
|
+
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
|
|
10547
|
+
|
|
10548
|
+
if (nth == 1 || nchunk < nth || disable_chunking) {
|
|
10549
|
+
nchunk = nth;
|
|
10550
|
+
}
|
|
10551
|
+
|
|
10552
|
+
if (ith == 0) {
|
|
10553
|
+
ggml_threadpool_chunk_set(params->threadpool, nth);
|
|
10554
|
+
}
|
|
10555
|
+
|
|
10556
|
+
ggml_barrier(params->threadpool);
|
|
10557
|
+
|
|
10558
|
+
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
|
10559
|
+
|
|
10560
|
+
int current_chunk = ith;
|
|
10561
|
+
|
|
10562
|
+
while (current_chunk < nchunk) {
|
|
10563
|
+
const int64_t ir0 = dr * current_chunk;
|
|
10564
|
+
const int64_t ir1 = MIN(ir0 + dr, nr);
|
|
10565
|
+
|
|
10566
|
+
ggml_compute_forward_gated_delta_net_one_chunk(params, dst, ir0, ir1);
|
|
10567
|
+
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
|
10568
|
+
}
|
|
10569
|
+
}
|
|
10570
|
+
|
|
10571
|
+
void ggml_compute_forward_gated_delta_net(
|
|
10572
|
+
const ggml_compute_params * params,
|
|
10573
|
+
ggml_tensor * dst) {
|
|
10574
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
10575
|
+
|
|
10576
|
+
switch (src0->type) {
|
|
10577
|
+
case GGML_TYPE_F32:
|
|
10578
|
+
{
|
|
10579
|
+
ggml_compute_forward_gated_delta_net_f32(params, dst);
|
|
10580
|
+
} break;
|
|
10581
|
+
default:
|
|
10582
|
+
{
|
|
10583
|
+
GGML_ABORT("fatal error");
|
|
10584
|
+
}
|
|
10585
|
+
}
|
|
10586
|
+
}
|
|
10587
|
+
|
|
9596
10588
|
// ggml_compute_forward_rwkv_wkv7
|
|
9597
10589
|
|
|
9598
10590
|
static void ggml_compute_forward_rwkv_wkv7_f32(
|
|
@@ -9918,7 +10910,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
|
|
|
9918
10910
|
assert(!isnan(s0[i]));
|
|
9919
10911
|
assert(!isnan(s1[i]));
|
|
9920
10912
|
}
|
|
9921
|
-
#endif
|
|
10913
|
+
#endif // NDEBUG
|
|
9922
10914
|
|
|
9923
10915
|
float max = -INFINITY;
|
|
9924
10916
|
ggml_vec_max_f32(nc, &max, s0);
|
|
@@ -9937,7 +10929,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
|
|
|
9937
10929
|
assert(!isnan(st[i]));
|
|
9938
10930
|
assert(!isinf(st[i]));
|
|
9939
10931
|
}
|
|
9940
|
-
#endif
|
|
10932
|
+
#endif // NDEBUG
|
|
9941
10933
|
}
|
|
9942
10934
|
sums[ith] = sum_thread;
|
|
9943
10935
|
ggml_barrier(params->threadpool);
|
|
@@ -10010,7 +11002,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
|
|
|
10010
11002
|
assert(!isnan(s0[i]));
|
|
10011
11003
|
assert(!isnan(s1[i]));
|
|
10012
11004
|
}
|
|
10013
|
-
#endif
|
|
11005
|
+
#endif // NDEBUG
|
|
10014
11006
|
|
|
10015
11007
|
// soft_max
|
|
10016
11008
|
float max = -INFINITY;
|
|
@@ -10028,7 +11020,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
|
|
|
10028
11020
|
assert(!isnan(ds0[i]));
|
|
10029
11021
|
assert(!isinf(ds0[i]));
|
|
10030
11022
|
}
|
|
10031
|
-
#endif
|
|
11023
|
+
#endif // NDEBUG
|
|
10032
11024
|
}
|
|
10033
11025
|
}
|
|
10034
11026
|
|