whispercpp 1.3.4 → 1.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +47 -23
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/cli/cli.cpp +121 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +10 -11
- data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
- data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
- data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
- data/ext/sources/examples/talk-llama/llama-context.h +57 -9
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
- data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
- data/ext/sources/examples/talk-llama/llama-model.h +44 -3
- data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
- data/ext/sources/examples/talk-llama/llama.cpp +729 -2
- data/ext/sources/examples/talk-llama/llama.h +152 -14
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +82 -54
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +4 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +190 -12
- data/ext/sources/ggml/src/CMakeLists.txt +82 -11
- data/ext/sources/ggml/src/ggml-alloc.c +124 -41
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
- data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
- data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
- data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
- data/ext/sources/ggml/src/ggml-impl.h +67 -6
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +425 -33
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +101 -35
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +119 -2
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +7 -0
- data/whispercpp.gemspec +1 -1
- metadata +287 -34
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -7,8 +7,10 @@
|
|
|
7
7
|
#include "unary-ops.h"
|
|
8
8
|
#include "vec.h"
|
|
9
9
|
|
|
10
|
-
#include <
|
|
10
|
+
#include <cfloat>
|
|
11
11
|
#include <algorithm>
|
|
12
|
+
#include <cmath>
|
|
13
|
+
#include <functional>
|
|
12
14
|
|
|
13
15
|
// ggml_compute_forward_dup
|
|
14
16
|
|
|
@@ -1394,6 +1396,56 @@ void ggml_compute_forward_sum(
|
|
|
1394
1396
|
}
|
|
1395
1397
|
}
|
|
1396
1398
|
|
|
1399
|
+
// ggml_compute_forward_cumsum
|
|
1400
|
+
|
|
1401
|
+
static void ggml_compute_forward_cumsum_f32(
|
|
1402
|
+
const ggml_compute_params * params,
|
|
1403
|
+
ggml_tensor * dst) {
|
|
1404
|
+
|
|
1405
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
1406
|
+
|
|
1407
|
+
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
|
1408
|
+
GGML_ASSERT(dst->nb[0] == sizeof(float));
|
|
1409
|
+
|
|
1410
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
|
1411
|
+
|
|
1412
|
+
GGML_ASSERT(ne0 == ne00);
|
|
1413
|
+
GGML_ASSERT(ne1 == ne01);
|
|
1414
|
+
GGML_ASSERT(ne2 == ne02);
|
|
1415
|
+
GGML_ASSERT(ne3 == ne03);
|
|
1416
|
+
|
|
1417
|
+
const auto [ir0, ir1] = get_thread_range(params, src0);
|
|
1418
|
+
|
|
1419
|
+
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
|
1420
|
+
const int64_t i03 = ir/(ne02*ne01);
|
|
1421
|
+
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
|
1422
|
+
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
|
1423
|
+
|
|
1424
|
+
float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
1425
|
+
float * dst_row = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
|
1426
|
+
|
|
1427
|
+
ggml_vec_cumsum_f32(ne00, dst_row, src_row);
|
|
1428
|
+
}
|
|
1429
|
+
}
|
|
1430
|
+
|
|
1431
|
+
void ggml_compute_forward_cumsum(
|
|
1432
|
+
const ggml_compute_params * params,
|
|
1433
|
+
ggml_tensor * dst) {
|
|
1434
|
+
|
|
1435
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
1436
|
+
|
|
1437
|
+
switch (src0->type) {
|
|
1438
|
+
case GGML_TYPE_F32:
|
|
1439
|
+
{
|
|
1440
|
+
ggml_compute_forward_cumsum_f32(params, dst);
|
|
1441
|
+
} break;
|
|
1442
|
+
default:
|
|
1443
|
+
{
|
|
1444
|
+
GGML_ABORT("fatal error");
|
|
1445
|
+
}
|
|
1446
|
+
}
|
|
1447
|
+
}
|
|
1448
|
+
|
|
1397
1449
|
// ggml_compute_forward_sum_rows
|
|
1398
1450
|
|
|
1399
1451
|
static void ggml_compute_forward_sum_rows_f32(
|
|
@@ -2140,6 +2192,83 @@ static void ggml_compute_forward_gelu(
|
|
|
2140
2192
|
}
|
|
2141
2193
|
}
|
|
2142
2194
|
|
|
2195
|
+
// ggml_compute_fill
|
|
2196
|
+
|
|
2197
|
+
static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
2198
|
+
const float c = ggml_get_op_params_f32(dst, 0);
|
|
2199
|
+
|
|
2200
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
|
|
2201
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
|
|
2202
|
+
|
|
2203
|
+
const auto [ir0, ir1] = get_thread_range(params, dst);
|
|
2204
|
+
|
|
2205
|
+
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
|
2206
|
+
const int64_t i03 = ir/(ne2*ne1);
|
|
2207
|
+
const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
|
|
2208
|
+
const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
|
|
2209
|
+
|
|
2210
|
+
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
|
|
2211
|
+
|
|
2212
|
+
ggml_vec_set_f32(ne0, dst_ptr, c);
|
|
2213
|
+
}
|
|
2214
|
+
}
|
|
2215
|
+
|
|
2216
|
+
void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
2217
|
+
ggml_compute_forward_fill_f32(params, dst);
|
|
2218
|
+
}
|
|
2219
|
+
|
|
2220
|
+
// ggml_compute_tri
|
|
2221
|
+
|
|
2222
|
+
static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
2223
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
2224
|
+
|
|
2225
|
+
const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
|
|
2226
|
+
|
|
2227
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
2228
|
+
|
|
2229
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
|
2230
|
+
|
|
2231
|
+
const auto [ir0, ir1] = get_thread_range(params, src0);
|
|
2232
|
+
|
|
2233
|
+
bool (*bipred)(int, int);
|
|
2234
|
+
|
|
2235
|
+
switch (ttype) {
|
|
2236
|
+
case GGML_TRI_TYPE_LOWER: bipred = [](int i, int r) { return i < r; }; break;
|
|
2237
|
+
case GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
|
|
2238
|
+
case GGML_TRI_TYPE_UPPER: bipred = [](int i, int r) { return i > r; }; break;
|
|
2239
|
+
case GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
|
|
2240
|
+
default: GGML_ABORT("invalid tri type");
|
|
2241
|
+
}
|
|
2242
|
+
|
|
2243
|
+
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
|
2244
|
+
const int64_t i03 = ir/(ne02*ne01);
|
|
2245
|
+
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
|
2246
|
+
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
|
2247
|
+
|
|
2248
|
+
const float * src_ptr = (const float *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
|
2249
|
+
float * dst_ptr = ( float *) (( char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
|
|
2250
|
+
|
|
2251
|
+
for (int i0 = 0; i0 < ne0; ++i0) {
|
|
2252
|
+
dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
|
|
2253
|
+
}
|
|
2254
|
+
}
|
|
2255
|
+
}
|
|
2256
|
+
|
|
2257
|
+
void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
2258
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
2259
|
+
|
|
2260
|
+
switch (src0->type) {
|
|
2261
|
+
case GGML_TYPE_F32:
|
|
2262
|
+
{
|
|
2263
|
+
ggml_compute_forward_tri_f32(params, dst);
|
|
2264
|
+
} break;
|
|
2265
|
+
default:
|
|
2266
|
+
{
|
|
2267
|
+
GGML_ABORT("fatal error");
|
|
2268
|
+
}
|
|
2269
|
+
}
|
|
2270
|
+
}
|
|
2271
|
+
|
|
2143
2272
|
// ggml_compute_forward_gelu_erf
|
|
2144
2273
|
|
|
2145
2274
|
static void ggml_compute_forward_gelu_erf_f32(
|
|
@@ -3467,31 +3596,27 @@ static void ggml_compute_forward_norm_f32(
|
|
|
3467
3596
|
|
|
3468
3597
|
GGML_ASSERT(eps >= 0.0f);
|
|
3469
3598
|
|
|
3470
|
-
// TODO: optimize
|
|
3471
3599
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
3472
3600
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
3473
3601
|
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
|
3474
3602
|
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
3475
3603
|
|
|
3476
|
-
|
|
3477
|
-
|
|
3478
|
-
sum += (ggml_float)x[i00];
|
|
3479
|
-
}
|
|
3480
|
-
|
|
3604
|
+
float sum = 0.0;
|
|
3605
|
+
ggml_vec_sum_f32(ne00, &sum, x);
|
|
3481
3606
|
float mean = sum/ne00;
|
|
3482
3607
|
|
|
3483
3608
|
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
|
3609
|
+
float variance = 0;
|
|
3484
3610
|
|
|
3485
|
-
|
|
3486
|
-
|
|
3487
|
-
|
|
3488
|
-
|
|
3489
|
-
|
|
3490
|
-
|
|
3611
|
+
#ifdef GGML_USE_ACCELERATE
|
|
3612
|
+
mean = -mean;
|
|
3613
|
+
vDSP_vsadd(x, 1, &mean, y, 1, ne00);
|
|
3614
|
+
vDSP_measqv(y, 1, &variance, ne00);
|
|
3615
|
+
#else
|
|
3616
|
+
variance = ggml_vec_cvar_f32(ne00, y, x, mean);
|
|
3617
|
+
#endif //GGML_USE_ACCELERATE
|
|
3491
3618
|
|
|
3492
|
-
float variance = sum2/ne00;
|
|
3493
3619
|
const float scale = 1.0f/sqrtf(variance + eps);
|
|
3494
|
-
|
|
3495
3620
|
ggml_vec_scale_f32(ne00, y, scale);
|
|
3496
3621
|
}
|
|
3497
3622
|
}
|
|
@@ -4459,46 +4584,6 @@ void ggml_compute_forward_cont(
|
|
|
4459
4584
|
ggml_compute_forward_dup(params, dst);
|
|
4460
4585
|
}
|
|
4461
4586
|
|
|
4462
|
-
// ggml_compute_forward_reshape
|
|
4463
|
-
|
|
4464
|
-
void ggml_compute_forward_reshape(
|
|
4465
|
-
const ggml_compute_params * params,
|
|
4466
|
-
ggml_tensor * dst) {
|
|
4467
|
-
// NOP
|
|
4468
|
-
GGML_UNUSED(params);
|
|
4469
|
-
GGML_UNUSED(dst);
|
|
4470
|
-
}
|
|
4471
|
-
|
|
4472
|
-
// ggml_compute_forward_view
|
|
4473
|
-
|
|
4474
|
-
void ggml_compute_forward_view(
|
|
4475
|
-
const ggml_compute_params * params,
|
|
4476
|
-
ggml_tensor * dst) {
|
|
4477
|
-
// NOP
|
|
4478
|
-
GGML_UNUSED(params);
|
|
4479
|
-
GGML_UNUSED(dst);
|
|
4480
|
-
}
|
|
4481
|
-
|
|
4482
|
-
// ggml_compute_forward_permute
|
|
4483
|
-
|
|
4484
|
-
void ggml_compute_forward_permute(
|
|
4485
|
-
const ggml_compute_params * params,
|
|
4486
|
-
ggml_tensor * dst) {
|
|
4487
|
-
// NOP
|
|
4488
|
-
GGML_UNUSED(params);
|
|
4489
|
-
GGML_UNUSED(dst);
|
|
4490
|
-
}
|
|
4491
|
-
|
|
4492
|
-
// ggml_compute_forward_transpose
|
|
4493
|
-
|
|
4494
|
-
void ggml_compute_forward_transpose(
|
|
4495
|
-
const ggml_compute_params * params,
|
|
4496
|
-
ggml_tensor * dst) {
|
|
4497
|
-
// NOP
|
|
4498
|
-
GGML_UNUSED(params);
|
|
4499
|
-
GGML_UNUSED(dst);
|
|
4500
|
-
}
|
|
4501
|
-
|
|
4502
4587
|
// ggml_compute_forward_get_rows
|
|
4503
4588
|
|
|
4504
4589
|
static void ggml_compute_forward_get_rows_q(
|
|
@@ -5478,7 +5563,7 @@ static void ggml_rope_cache_init(
|
|
|
5478
5563
|
}
|
|
5479
5564
|
|
|
5480
5565
|
static void ggml_mrope_cache_init(
|
|
5481
|
-
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
|
|
5566
|
+
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
|
|
5482
5567
|
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
|
|
5483
5568
|
float * cache, float sin_sign, float theta_scale) {
|
|
5484
5569
|
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
|
@@ -5513,14 +5598,26 @@ static void ggml_mrope_cache_init(
|
|
|
5513
5598
|
}
|
|
5514
5599
|
|
|
5515
5600
|
float theta = theta_t;
|
|
5516
|
-
if (
|
|
5517
|
-
|
|
5518
|
-
|
|
5519
|
-
|
|
5520
|
-
|
|
5521
|
-
|
|
5522
|
-
|
|
5523
|
-
|
|
5601
|
+
if (is_imrope) { // qwen3vl apply interleaved mrope
|
|
5602
|
+
if (sector % 3 == 1 && sector < 3 * sections[1]) {
|
|
5603
|
+
theta = theta_h;
|
|
5604
|
+
} else if (sector % 3 == 2 && sector < 3 * sections[2]) {
|
|
5605
|
+
theta = theta_w;
|
|
5606
|
+
} else if (sector % 3 == 0 && sector < 3 * sections[0]) {
|
|
5607
|
+
theta = theta_t;
|
|
5608
|
+
} else {
|
|
5609
|
+
theta = theta_e;
|
|
5610
|
+
}
|
|
5611
|
+
} else {
|
|
5612
|
+
if (sector >= sections[0] && sector < sec_w) {
|
|
5613
|
+
theta = theta_h;
|
|
5614
|
+
}
|
|
5615
|
+
else if (sector >= sec_w && sector < sec_w + sections[2]) {
|
|
5616
|
+
theta = theta_w;
|
|
5617
|
+
}
|
|
5618
|
+
else if (sector >= sec_w + sections[2]) {
|
|
5619
|
+
theta = theta_e;
|
|
5620
|
+
}
|
|
5524
5621
|
}
|
|
5525
5622
|
|
|
5526
5623
|
rope_yarn(
|
|
@@ -5535,7 +5632,28 @@ static void ggml_mrope_cache_init(
|
|
|
5535
5632
|
}
|
|
5536
5633
|
}
|
|
5537
5634
|
|
|
5538
|
-
|
|
5635
|
+
|
|
5636
|
+
template<typename T>
|
|
5637
|
+
static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
|
|
5638
|
+
for (int64_t i0 = 0; i0 < n; i0 += 2) {
|
|
5639
|
+
const int64_t ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
|
|
5640
|
+
|
|
5641
|
+
const float cos_theta = cache[i0 + 0];
|
|
5642
|
+
const float sin_theta = cache[i0 + 1];
|
|
5643
|
+
|
|
5644
|
+
const T * const src = src_data + ic;
|
|
5645
|
+
T * dst = dst_data + ic;
|
|
5646
|
+
|
|
5647
|
+
const float x0 = type_conversion_table<T>::to_f32(src[0]);
|
|
5648
|
+
const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);
|
|
5649
|
+
|
|
5650
|
+
dst[0] = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
|
|
5651
|
+
dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
|
|
5652
|
+
}
|
|
5653
|
+
}
|
|
5654
|
+
|
|
5655
|
+
template<typename T> //float or ggml_fp16_t
|
|
5656
|
+
static void ggml_compute_forward_rope_flt(
|
|
5539
5657
|
const ggml_compute_params * params,
|
|
5540
5658
|
ggml_tensor * dst,
|
|
5541
5659
|
const bool forward) {
|
|
@@ -5544,6 +5662,9 @@ static void ggml_compute_forward_rope_f32(
|
|
|
5544
5662
|
const ggml_tensor * src1 = dst->src[1];
|
|
5545
5663
|
const ggml_tensor * src2 = dst->src[2];
|
|
5546
5664
|
|
|
5665
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
5666
|
+
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
|
5667
|
+
|
|
5547
5668
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
|
5548
5669
|
int sections[4];
|
|
5549
5670
|
|
|
@@ -5566,7 +5687,8 @@ static void ggml_compute_forward_rope_f32(
|
|
|
5566
5687
|
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
|
5567
5688
|
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
|
5568
5689
|
|
|
5569
|
-
GGML_ASSERT(
|
|
5690
|
+
GGML_ASSERT(nb0 == nb00);
|
|
5691
|
+
GGML_ASSERT(nb0 == sizeof(T));
|
|
5570
5692
|
|
|
5571
5693
|
const int ith = params->ith;
|
|
5572
5694
|
const int nth = params->nth;
|
|
@@ -5591,11 +5713,11 @@ static void ggml_compute_forward_rope_f32(
|
|
|
5591
5713
|
float corr_dims[2];
|
|
5592
5714
|
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
|
5593
5715
|
|
|
5594
|
-
const bool
|
|
5595
|
-
const bool
|
|
5716
|
+
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
|
|
5717
|
+
const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
|
|
5596
5718
|
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
|
5597
5719
|
|
|
5598
|
-
if (
|
|
5720
|
+
if (mrope_used) {
|
|
5599
5721
|
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
|
5600
5722
|
}
|
|
5601
5723
|
|
|
@@ -5621,7 +5743,7 @@ static void ggml_compute_forward_rope_f32(
|
|
|
5621
5743
|
for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
|
5622
5744
|
|
|
5623
5745
|
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
|
5624
|
-
if (!
|
|
5746
|
+
if (!mrope_used) {
|
|
5625
5747
|
const int64_t p = pos[i2];
|
|
5626
5748
|
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
5627
5749
|
}
|
|
@@ -5631,7 +5753,7 @@ static void ggml_compute_forward_rope_f32(
|
|
|
5631
5753
|
const int64_t p_w = pos[i2 + ne2 * 2];
|
|
5632
5754
|
const int64_t p_e = pos[i2 + ne2 * 3];
|
|
5633
5755
|
ggml_mrope_cache_init(
|
|
5634
|
-
p_t, p_h, p_w, p_e, sections, is_vision,
|
|
5756
|
+
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
|
|
5635
5757
|
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
5636
5758
|
}
|
|
5637
5759
|
|
|
@@ -5639,347 +5761,115 @@ static void ggml_compute_forward_rope_f32(
|
|
|
5639
5761
|
if (ir++ < ir0) continue;
|
|
5640
5762
|
if (ir > ir1) break;
|
|
5641
5763
|
|
|
5642
|
-
|
|
5643
|
-
|
|
5644
|
-
|
|
5645
|
-
|
|
5646
|
-
|
|
5647
|
-
|
|
5648
|
-
|
|
5649
|
-
|
|
5650
|
-
|
|
5651
|
-
|
|
5652
|
-
|
|
5653
|
-
|
|
5654
|
-
|
|
5655
|
-
|
|
5656
|
-
|
|
5657
|
-
|
|
5658
|
-
|
|
5659
|
-
} else {
|
|
5660
|
-
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
5661
|
-
const int64_t ic = i0/2;
|
|
5662
|
-
|
|
5663
|
-
const float cos_theta = cache[i0 + 0];
|
|
5664
|
-
const float sin_theta = cache[i0 + 1];
|
|
5665
|
-
|
|
5666
|
-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
5667
|
-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
5668
|
-
|
|
5669
|
-
const float x0 = src[0];
|
|
5670
|
-
const float x1 = src[n_dims/2];
|
|
5671
|
-
|
|
5672
|
-
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
5673
|
-
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
|
5674
|
-
}
|
|
5675
|
-
}
|
|
5676
|
-
} else {
|
|
5677
|
-
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
5678
|
-
const float cos_theta = cache[i0 + 0];
|
|
5679
|
-
const float sin_theta = cache[i0 + 1];
|
|
5680
|
-
|
|
5681
|
-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
5682
|
-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
5683
|
-
|
|
5684
|
-
const float x0 = src[0];
|
|
5685
|
-
const float x1 = src[1];
|
|
5686
|
-
|
|
5687
|
-
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
5688
|
-
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
|
5689
|
-
}
|
|
5764
|
+
T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
|
|
5765
|
+
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
|
|
5766
|
+
|
|
5767
|
+
switch (mode) {
|
|
5768
|
+
case GGML_ROPE_TYPE_NORMAL:
|
|
5769
|
+
rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
|
|
5770
|
+
break;
|
|
5771
|
+
case GGML_ROPE_TYPE_NEOX:
|
|
5772
|
+
case GGML_ROPE_TYPE_MROPE:
|
|
5773
|
+
case GGML_ROPE_TYPE_IMROPE:
|
|
5774
|
+
rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
|
|
5775
|
+
break;
|
|
5776
|
+
case GGML_ROPE_TYPE_VISION:
|
|
5777
|
+
rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
|
|
5778
|
+
break;
|
|
5779
|
+
default:
|
|
5780
|
+
GGML_ABORT("rope type not supported");
|
|
5690
5781
|
}
|
|
5691
5782
|
|
|
5692
|
-
if (is_vision) {
|
|
5693
|
-
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
|
5694
|
-
const int64_t ic = i0/2;
|
|
5695
|
-
|
|
5696
|
-
const float cos_theta = cache[i0 + 0];
|
|
5697
|
-
const float sin_theta = cache[i0 + 1];
|
|
5698
|
-
|
|
5699
|
-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
5700
|
-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
5701
|
-
|
|
5702
|
-
const float x0 = src[0];
|
|
5703
|
-
const float x1 = src[n_dims];
|
|
5704
|
-
|
|
5705
|
-
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
5706
|
-
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
|
|
5707
|
-
}
|
|
5708
|
-
} else {
|
|
5783
|
+
if (!is_vision) {
|
|
5709
5784
|
// fill the remain channels with data from src tensor
|
|
5710
5785
|
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
|
5711
|
-
const
|
|
5712
|
-
|
|
5786
|
+
const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
5787
|
+
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
5713
5788
|
|
|
5714
5789
|
dst_data[0] = src[0];
|
|
5715
5790
|
dst_data[1] = src[1];
|
|
5716
5791
|
}
|
|
5717
5792
|
}
|
|
5718
|
-
}
|
|
5793
|
+
} //attn-heads
|
|
5719
5794
|
}
|
|
5720
5795
|
}
|
|
5721
5796
|
}
|
|
5722
5797
|
|
|
5723
|
-
|
|
5724
|
-
static void ggml_compute_forward_rope_f16(
|
|
5798
|
+
void ggml_compute_forward_rope(
|
|
5725
5799
|
const ggml_compute_params * params,
|
|
5726
|
-
ggml_tensor * dst
|
|
5727
|
-
const bool forward) {
|
|
5800
|
+
ggml_tensor * dst) {
|
|
5728
5801
|
|
|
5729
5802
|
const ggml_tensor * src0 = dst->src[0];
|
|
5730
|
-
const ggml_tensor * src1 = dst->src[1];
|
|
5731
|
-
const ggml_tensor * src2 = dst->src[2];
|
|
5732
5803
|
|
|
5733
|
-
|
|
5734
|
-
|
|
5804
|
+
switch (src0->type) {
|
|
5805
|
+
case GGML_TYPE_F16:
|
|
5806
|
+
{
|
|
5807
|
+
ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);
|
|
5808
|
+
} break;
|
|
5809
|
+
case GGML_TYPE_F32:
|
|
5810
|
+
{
|
|
5811
|
+
ggml_compute_forward_rope_flt<float>(params, dst, true);
|
|
5812
|
+
} break;
|
|
5813
|
+
default:
|
|
5814
|
+
{
|
|
5815
|
+
GGML_ABORT("fatal error");
|
|
5816
|
+
}
|
|
5817
|
+
}
|
|
5818
|
+
}
|
|
5735
5819
|
|
|
5736
|
-
|
|
5737
|
-
const int n_dims = ((int32_t *) dst->op_params)[1];
|
|
5738
|
-
const int mode = ((int32_t *) dst->op_params)[2];
|
|
5739
|
-
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
|
5740
|
-
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
|
5741
|
-
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
|
5742
|
-
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
|
5743
|
-
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
|
5744
|
-
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
|
5745
|
-
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
|
5746
|
-
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
|
5747
|
-
memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
|
5820
|
+
// ggml_compute_forward_rope_back
|
|
5748
5821
|
|
|
5822
|
+
void ggml_compute_forward_rope_back(
|
|
5823
|
+
const ggml_compute_params * params,
|
|
5824
|
+
ggml_tensor * dst) {
|
|
5749
5825
|
|
|
5750
|
-
|
|
5826
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
5751
5827
|
|
|
5752
|
-
|
|
5753
|
-
|
|
5828
|
+
switch (src0->type) {
|
|
5829
|
+
case GGML_TYPE_F16:
|
|
5830
|
+
{
|
|
5831
|
+
ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);
|
|
5832
|
+
} break;
|
|
5833
|
+
case GGML_TYPE_F32:
|
|
5834
|
+
{
|
|
5835
|
+
ggml_compute_forward_rope_flt<float>(params, dst, false);
|
|
5836
|
+
} break;
|
|
5837
|
+
default:
|
|
5838
|
+
{
|
|
5839
|
+
GGML_ABORT("fatal error");
|
|
5840
|
+
}
|
|
5841
|
+
}
|
|
5842
|
+
}
|
|
5754
5843
|
|
|
5755
|
-
|
|
5844
|
+
// ggml_compute_forward_conv_transpose_1d
|
|
5756
5845
|
|
|
5757
|
-
|
|
5758
|
-
|
|
5846
|
+
static void ggml_compute_forward_conv_transpose_1d_f16_f32(
|
|
5847
|
+
const ggml_compute_params * params,
|
|
5848
|
+
ggml_tensor * dst) {
|
|
5759
5849
|
|
|
5760
|
-
const
|
|
5850
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
5851
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
5761
5852
|
|
|
5762
|
-
GGML_ASSERT(
|
|
5763
|
-
GGML_ASSERT(
|
|
5853
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
|
5854
|
+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
5855
|
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
5764
5856
|
|
|
5765
|
-
|
|
5766
|
-
const int dr = (nr + nth - 1)/nth;
|
|
5857
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
|
5767
5858
|
|
|
5768
|
-
|
|
5769
|
-
const int
|
|
5770
|
-
const int ir1 = MIN(ir0 + dr, nr);
|
|
5859
|
+
const int ith = params->ith;
|
|
5860
|
+
const int nth = params->nth;
|
|
5771
5861
|
|
|
5772
|
-
|
|
5773
|
-
int ir = 0;
|
|
5862
|
+
const int nk = ne00*ne01*ne02;
|
|
5774
5863
|
|
|
5775
|
-
|
|
5864
|
+
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
|
5865
|
+
GGML_ASSERT(nb10 == sizeof(float));
|
|
5776
5866
|
|
|
5777
|
-
|
|
5778
|
-
|
|
5867
|
+
if (ith == 0) {
|
|
5868
|
+
memset(params->wdata, 0, params->wsize);
|
|
5779
5869
|
|
|
5780
|
-
|
|
5781
|
-
|
|
5782
|
-
|
|
5783
|
-
|
|
5784
|
-
if (is_mrope) {
|
|
5785
|
-
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
|
5786
|
-
}
|
|
5787
|
-
|
|
5788
|
-
if (is_vision) {
|
|
5789
|
-
GGML_ASSERT(n_dims == ne0/2);
|
|
5790
|
-
}
|
|
5791
|
-
|
|
5792
|
-
const float * freq_factors = NULL;
|
|
5793
|
-
if (src2 != NULL) {
|
|
5794
|
-
GGML_ASSERT(src2->type == GGML_TYPE_F32);
|
|
5795
|
-
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
|
|
5796
|
-
freq_factors = (const float *) src2->data;
|
|
5797
|
-
}
|
|
5798
|
-
|
|
5799
|
-
// backward process uses inverse rotation by cos and sin.
|
|
5800
|
-
// cos and sin build a rotation matrix, where the inverse is the transpose.
|
|
5801
|
-
// this essentially just switches the sign of sin.
|
|
5802
|
-
const float sin_sign = forward ? 1.0f : -1.0f;
|
|
5803
|
-
|
|
5804
|
-
const int32_t * pos = (const int32_t *) src1->data;
|
|
5805
|
-
|
|
5806
|
-
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
5807
|
-
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
|
5808
|
-
|
|
5809
|
-
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
|
5810
|
-
if (!is_mrope) {
|
|
5811
|
-
const int64_t p = pos[i2];
|
|
5812
|
-
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
5813
|
-
}
|
|
5814
|
-
else {
|
|
5815
|
-
const int64_t p_t = pos[i2];
|
|
5816
|
-
const int64_t p_h = pos[i2 + ne2];
|
|
5817
|
-
const int64_t p_w = pos[i2 + ne2 * 2];
|
|
5818
|
-
const int64_t p_e = pos[i2 + ne2 * 3];
|
|
5819
|
-
ggml_mrope_cache_init(
|
|
5820
|
-
p_t, p_h, p_w, p_e, sections, is_vision,
|
|
5821
|
-
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
5822
|
-
}
|
|
5823
|
-
|
|
5824
|
-
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
5825
|
-
if (ir++ < ir0) continue;
|
|
5826
|
-
if (ir > ir1) break;
|
|
5827
|
-
|
|
5828
|
-
if (is_neox || is_mrope) {
|
|
5829
|
-
if (is_vision) {
|
|
5830
|
-
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
5831
|
-
const int64_t ic = i0/2;
|
|
5832
|
-
|
|
5833
|
-
const float cos_theta = cache[i0 + 0];
|
|
5834
|
-
const float sin_theta = cache[i0 + 1];
|
|
5835
|
-
|
|
5836
|
-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
5837
|
-
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
5838
|
-
|
|
5839
|
-
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
|
5840
|
-
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
|
|
5841
|
-
|
|
5842
|
-
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
5843
|
-
dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
5844
|
-
}
|
|
5845
|
-
} else {
|
|
5846
|
-
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
5847
|
-
const int64_t ic = i0/2;
|
|
5848
|
-
|
|
5849
|
-
const float cos_theta = cache[i0 + 0];
|
|
5850
|
-
const float sin_theta = cache[i0 + 1];
|
|
5851
|
-
|
|
5852
|
-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
5853
|
-
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
5854
|
-
|
|
5855
|
-
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
|
5856
|
-
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
|
|
5857
|
-
|
|
5858
|
-
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
5859
|
-
dst_data[n_dims/2] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
5860
|
-
}
|
|
5861
|
-
}
|
|
5862
|
-
} else {
|
|
5863
|
-
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
5864
|
-
const float cos_theta = cache[i0 + 0];
|
|
5865
|
-
const float sin_theta = cache[i0 + 1];
|
|
5866
|
-
|
|
5867
|
-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
5868
|
-
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
5869
|
-
|
|
5870
|
-
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
|
5871
|
-
const float x1 = GGML_CPU_FP16_TO_FP32(src[1]);
|
|
5872
|
-
|
|
5873
|
-
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
5874
|
-
dst_data[1] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
5875
|
-
}
|
|
5876
|
-
}
|
|
5877
|
-
|
|
5878
|
-
if (is_vision) {
|
|
5879
|
-
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
|
5880
|
-
const int64_t ic = i0/2;
|
|
5881
|
-
|
|
5882
|
-
const float cos_theta = cache[i0 + 0];
|
|
5883
|
-
const float sin_theta = cache[i0 + 1];
|
|
5884
|
-
|
|
5885
|
-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
5886
|
-
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
5887
|
-
|
|
5888
|
-
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
|
5889
|
-
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
|
|
5890
|
-
|
|
5891
|
-
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
5892
|
-
dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
5893
|
-
}
|
|
5894
|
-
} else {
|
|
5895
|
-
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
|
5896
|
-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
5897
|
-
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
5898
|
-
|
|
5899
|
-
dst_data[0] = src[0];
|
|
5900
|
-
dst_data[1] = src[1];
|
|
5901
|
-
}
|
|
5902
|
-
}
|
|
5903
|
-
}
|
|
5904
|
-
}
|
|
5905
|
-
}
|
|
5906
|
-
}
|
|
5907
|
-
|
|
5908
|
-
void ggml_compute_forward_rope(
|
|
5909
|
-
const ggml_compute_params * params,
|
|
5910
|
-
ggml_tensor * dst) {
|
|
5911
|
-
|
|
5912
|
-
const ggml_tensor * src0 = dst->src[0];
|
|
5913
|
-
|
|
5914
|
-
switch (src0->type) {
|
|
5915
|
-
case GGML_TYPE_F16:
|
|
5916
|
-
{
|
|
5917
|
-
ggml_compute_forward_rope_f16(params, dst, true);
|
|
5918
|
-
} break;
|
|
5919
|
-
case GGML_TYPE_F32:
|
|
5920
|
-
{
|
|
5921
|
-
ggml_compute_forward_rope_f32(params, dst, true);
|
|
5922
|
-
} break;
|
|
5923
|
-
default:
|
|
5924
|
-
{
|
|
5925
|
-
GGML_ABORT("fatal error");
|
|
5926
|
-
}
|
|
5927
|
-
}
|
|
5928
|
-
}
|
|
5929
|
-
|
|
5930
|
-
// ggml_compute_forward_rope_back
|
|
5931
|
-
|
|
5932
|
-
void ggml_compute_forward_rope_back(
|
|
5933
|
-
const ggml_compute_params * params,
|
|
5934
|
-
ggml_tensor * dst) {
|
|
5935
|
-
|
|
5936
|
-
const ggml_tensor * src0 = dst->src[0];
|
|
5937
|
-
|
|
5938
|
-
switch (src0->type) {
|
|
5939
|
-
case GGML_TYPE_F16:
|
|
5940
|
-
{
|
|
5941
|
-
ggml_compute_forward_rope_f16(params, dst, false);
|
|
5942
|
-
} break;
|
|
5943
|
-
case GGML_TYPE_F32:
|
|
5944
|
-
{
|
|
5945
|
-
ggml_compute_forward_rope_f32(params, dst, false);
|
|
5946
|
-
} break;
|
|
5947
|
-
default:
|
|
5948
|
-
{
|
|
5949
|
-
GGML_ABORT("fatal error");
|
|
5950
|
-
}
|
|
5951
|
-
}
|
|
5952
|
-
}
|
|
5953
|
-
|
|
5954
|
-
// ggml_compute_forward_conv_transpose_1d
|
|
5955
|
-
|
|
5956
|
-
static void ggml_compute_forward_conv_transpose_1d_f16_f32(
|
|
5957
|
-
const ggml_compute_params * params,
|
|
5958
|
-
ggml_tensor * dst) {
|
|
5959
|
-
|
|
5960
|
-
const ggml_tensor * src0 = dst->src[0];
|
|
5961
|
-
const ggml_tensor * src1 = dst->src[1];
|
|
5962
|
-
|
|
5963
|
-
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
|
5964
|
-
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
5965
|
-
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
5966
|
-
|
|
5967
|
-
GGML_TENSOR_BINARY_OP_LOCALS
|
|
5968
|
-
|
|
5969
|
-
const int ith = params->ith;
|
|
5970
|
-
const int nth = params->nth;
|
|
5971
|
-
|
|
5972
|
-
const int nk = ne00*ne01*ne02;
|
|
5973
|
-
|
|
5974
|
-
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
|
5975
|
-
GGML_ASSERT(nb10 == sizeof(float));
|
|
5976
|
-
|
|
5977
|
-
if (ith == 0) {
|
|
5978
|
-
memset(params->wdata, 0, params->wsize);
|
|
5979
|
-
|
|
5980
|
-
// permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
|
|
5981
|
-
{
|
|
5982
|
-
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
|
|
5870
|
+
// permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
|
|
5871
|
+
{
|
|
5872
|
+
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
|
|
5983
5873
|
|
|
5984
5874
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
5985
5875
|
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
|
@@ -6493,7 +6383,7 @@ static void ggml_compute_forward_im2col_3d_f16(
|
|
|
6493
6383
|
const int64_t iih = ioh*s1 + ikh*d1 - p1;
|
|
6494
6384
|
const int64_t iid = iod*s2 + ikd*d2 - p2;
|
|
6495
6385
|
|
|
6496
|
-
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW
|
|
6386
|
+
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
|
6497
6387
|
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
|
|
6498
6388
|
} else {
|
|
6499
6389
|
const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
|
|
@@ -6664,8 +6554,13 @@ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params
|
|
|
6664
6554
|
ggml_compute_forward_mul_mat(params, &dst);
|
|
6665
6555
|
}
|
|
6666
6556
|
|
|
6557
|
+
static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
|
|
6558
|
+
return (coord + size) % size; // adding size avoids negative number weirdness
|
|
6559
|
+
}
|
|
6560
|
+
|
|
6667
6561
|
// ggml_compute_forward_conv_2d
|
|
6668
6562
|
|
|
6563
|
+
|
|
6669
6564
|
static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
|
|
6670
6565
|
const ggml_tensor * kernel, // [KW, KH, IC, OC]
|
|
6671
6566
|
const ggml_tensor * src, // [W, H, C, N]
|
|
@@ -7074,7 +6969,11 @@ static void ggml_compute_forward_conv_2d_dw_cwhn(
|
|
|
7074
6969
|
const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
|
|
7075
6970
|
|
|
7076
6971
|
#ifdef GGML_SIMD
|
|
7077
|
-
|
|
6972
|
+
#if defined(__ARM_FEATURE_SVE)
|
|
6973
|
+
const int64_t pkg_size = svcntw();
|
|
6974
|
+
#else
|
|
6975
|
+
const int64_t pkg_size = GGML_F32_EPR;
|
|
6976
|
+
#endif
|
|
7078
6977
|
const int64_t pkg_count = c / pkg_size;
|
|
7079
6978
|
const int64_t c_pkg_end = pkg_count * pkg_size;
|
|
7080
6979
|
#else
|
|
@@ -7497,10 +7396,17 @@ static void ggml_compute_forward_upscale_f32(
|
|
|
7497
7396
|
float sf1 = (float)ne1/src0->ne[1];
|
|
7498
7397
|
float sf2 = (float)ne2/src0->ne[2];
|
|
7499
7398
|
float sf3 = (float)ne3/src0->ne[3];
|
|
7399
|
+
float pixel_offset = 0.5f;
|
|
7500
7400
|
|
|
7501
7401
|
const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
|
|
7502
7402
|
const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
|
|
7503
7403
|
|
|
7404
|
+
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
|
7405
|
+
pixel_offset = 0.0f;
|
|
7406
|
+
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
|
|
7407
|
+
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
|
|
7408
|
+
}
|
|
7409
|
+
|
|
7504
7410
|
if (mode == GGML_SCALE_MODE_NEAREST) {
|
|
7505
7411
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
7506
7412
|
const int64_t i03 = i3 / sf3;
|
|
@@ -7519,14 +7425,66 @@ static void ggml_compute_forward_upscale_f32(
|
|
|
7519
7425
|
}
|
|
7520
7426
|
}
|
|
7521
7427
|
}
|
|
7522
|
-
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
|
|
7523
|
-
|
|
7524
|
-
|
|
7525
|
-
|
|
7526
|
-
|
|
7527
|
-
|
|
7528
|
-
|
|
7428
|
+
} else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) {
|
|
7429
|
+
// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
|
|
7430
|
+
// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
|
|
7431
|
+
auto triangle_filter = [](float x) -> float {
|
|
7432
|
+
return std::max(1.0f - fabsf(x), 0.0f);
|
|
7433
|
+
};
|
|
7434
|
+
|
|
7435
|
+
// support and invscale, minimum 1 pixel for bilinear
|
|
7436
|
+
const float support1 = std::max(1.0f, 1.0f / sf1);
|
|
7437
|
+
const float invscale1 = 1.0f / support1;
|
|
7438
|
+
const float support0 = std::max(1.0f, 1.0f / sf0);
|
|
7439
|
+
const float invscale0 = 1.0f / support0;
|
|
7529
7440
|
|
|
7441
|
+
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
7442
|
+
const int64_t i03 = i3 / sf3;
|
|
7443
|
+
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
|
|
7444
|
+
const int64_t i02 = i2 / sf2;
|
|
7445
|
+
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
7446
|
+
const float y = ((float) i1 + pixel_offset) / sf1;
|
|
7447
|
+
for (int64_t i0 = 0; i0 < ne0; i0++) {
|
|
7448
|
+
const float x = ((float) i0 + pixel_offset) / sf0;
|
|
7449
|
+
|
|
7450
|
+
// the range of source pixels that contribute
|
|
7451
|
+
const int64_t x_min = std::max<int64_t>(x - support0 + pixel_offset, 0);
|
|
7452
|
+
const int64_t x_max = std::min<int64_t>(x + support0 + pixel_offset, ne00);
|
|
7453
|
+
const int64_t y_min = std::max<int64_t>(y - support1 + pixel_offset, 0);
|
|
7454
|
+
const int64_t y_max = std::min<int64_t>(y + support1 + pixel_offset, ne01);
|
|
7455
|
+
|
|
7456
|
+
// bilinear filter with antialiasing
|
|
7457
|
+
float val = 0.0f;
|
|
7458
|
+
float total_weight = 0.0f;
|
|
7459
|
+
|
|
7460
|
+
for (int64_t sy = y_min; sy < y_max; sy++) {
|
|
7461
|
+
const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
|
|
7462
|
+
|
|
7463
|
+
for (int64_t sx = x_min; sx < x_max; sx++) {
|
|
7464
|
+
const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
|
|
7465
|
+
const float weight = weight_x * weight_y;
|
|
7466
|
+
|
|
7467
|
+
if (weight <= 0.0f) {
|
|
7468
|
+
continue;
|
|
7469
|
+
}
|
|
7470
|
+
|
|
7471
|
+
const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03);
|
|
7472
|
+
val += pixel * weight;
|
|
7473
|
+
total_weight += weight;
|
|
7474
|
+
}
|
|
7475
|
+
}
|
|
7476
|
+
|
|
7477
|
+
if (total_weight > 0.0f) {
|
|
7478
|
+
val /= total_weight;
|
|
7479
|
+
}
|
|
7480
|
+
|
|
7481
|
+
float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
|
|
7482
|
+
*dst_ptr = val;
|
|
7483
|
+
}
|
|
7484
|
+
}
|
|
7485
|
+
}
|
|
7486
|
+
}
|
|
7487
|
+
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
|
|
7530
7488
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
7531
7489
|
const int64_t i03 = i3 / sf3;
|
|
7532
7490
|
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
|
|
@@ -7561,6 +7519,51 @@ static void ggml_compute_forward_upscale_f32(
|
|
|
7561
7519
|
|
|
7562
7520
|
const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
|
|
7563
7521
|
|
|
7522
|
+
float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
|
|
7523
|
+
*y_dst = val;
|
|
7524
|
+
}
|
|
7525
|
+
}
|
|
7526
|
+
}
|
|
7527
|
+
}
|
|
7528
|
+
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
|
|
7529
|
+
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
|
|
7530
|
+
const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
|
|
7531
|
+
auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
|
|
7532
|
+
auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
|
|
7533
|
+
auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
|
|
7534
|
+
const float w0 = weight2(x + 1);
|
|
7535
|
+
const float w1 = weight1(x + 0);
|
|
7536
|
+
const float w2 = weight1(1 - x);
|
|
7537
|
+
const float w3 = weight2(2 - x);
|
|
7538
|
+
return p0*w0 + p1*w1 + p2*w2 + p3*w3;
|
|
7539
|
+
};
|
|
7540
|
+
|
|
7541
|
+
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
7542
|
+
const int64_t i03 = i3 / sf3;
|
|
7543
|
+
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
|
|
7544
|
+
const int64_t i02 = i2 / sf2;
|
|
7545
|
+
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
7546
|
+
const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
|
|
7547
|
+
const int64_t y0 = (int64_t)floorf(y);
|
|
7548
|
+
const float dy = y - (float)y0;
|
|
7549
|
+
|
|
7550
|
+
for (int64_t i0 = 0; i0 < ne0; i0++) {
|
|
7551
|
+
const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
|
|
7552
|
+
const int64_t x0 = (int64_t)floorf(x);
|
|
7553
|
+
const float dx = x - (float)x0;
|
|
7554
|
+
|
|
7555
|
+
auto p = [=](int64_t x_off, int64_t y_off) -> float {
|
|
7556
|
+
int64_t i00 = std::max(int64_t(0), std::min(x0 + x_off, ne00 - 1));
|
|
7557
|
+
int64_t i01 = std::max(int64_t(0), std::min(y0 + y_off, ne01 - 1));
|
|
7558
|
+
return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
7559
|
+
};
|
|
7560
|
+
|
|
7561
|
+
const float val = bicubic(
|
|
7562
|
+
bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),
|
|
7563
|
+
bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),
|
|
7564
|
+
bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),
|
|
7565
|
+
bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);
|
|
7566
|
+
|
|
7564
7567
|
float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
|
|
7565
7568
|
*y_dst = val;
|
|
7566
7569
|
}
|
|
@@ -7593,6 +7596,7 @@ void ggml_compute_forward_upscale(
|
|
|
7593
7596
|
|
|
7594
7597
|
// ggml_compute_forward_pad
|
|
7595
7598
|
|
|
7599
|
+
template<bool circular_t>
|
|
7596
7600
|
static void ggml_compute_forward_pad_f32(
|
|
7597
7601
|
const ggml_compute_params * params,
|
|
7598
7602
|
ggml_tensor * dst) {
|
|
@@ -7617,23 +7621,40 @@ static void ggml_compute_forward_pad_f32(
|
|
|
7617
7621
|
const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
|
|
7618
7622
|
const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
|
|
7619
7623
|
|
|
7620
|
-
|
|
7621
7624
|
// TODO: optimize
|
|
7622
7625
|
|
|
7623
7626
|
for (int64_t i2 = 0; i2 < ne2; ++i2) {
|
|
7624
7627
|
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
|
|
7625
7628
|
for (int64_t i0 = 0; i0 < ne0; ++i0) {
|
|
7626
7629
|
for (int64_t i3 = 0; i3 < ne3; ++i3) {
|
|
7627
|
-
|
|
7628
|
-
if (
|
|
7629
|
-
|
|
7630
|
-
|
|
7631
|
-
|
|
7632
|
-
const int64_t
|
|
7630
|
+
// circular means wrap around on a torus, so x and y loop around
|
|
7631
|
+
if constexpr (circular_t) {
|
|
7632
|
+
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
|
7633
|
+
const int64_t src_i0 = ggml_wrap_around(i0 - lp0, ne00);
|
|
7634
|
+
const int64_t src_i1 = ggml_wrap_around(i1 - lp1, ne01);
|
|
7635
|
+
const int64_t src_i2 = ggml_wrap_around(i2 - lp2, ne02);
|
|
7636
|
+
const int64_t src_i3 = ggml_wrap_around(i3 - lp3, ne03);
|
|
7637
|
+
|
|
7638
|
+
const int64_t src_idx =
|
|
7639
|
+
src_i3*nb03 +
|
|
7640
|
+
src_i2*nb02 +
|
|
7641
|
+
src_i1*nb01 +
|
|
7642
|
+
src_i0*nb00;
|
|
7643
|
+
|
|
7633
7644
|
const float * src_ptr = (const float *)((char *) src0->data + src_idx);
|
|
7634
7645
|
dst_ptr[dst_idx] = *src_ptr;
|
|
7635
7646
|
} else {
|
|
7636
|
-
|
|
7647
|
+
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
|
7648
|
+
if ((i0 >= lp0 && i0 < ne0 - rp0) \
|
|
7649
|
+
&& (i1 >= lp1 && i1 < ne1 - rp1) \
|
|
7650
|
+
&& (i2 >= lp2 && i2 < ne2 - rp2) \
|
|
7651
|
+
&& (i3 >= lp3 && i3 < ne3 - rp3)) {
|
|
7652
|
+
const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
|
|
7653
|
+
const float * src_ptr = (const float *)((char *) src0->data + src_idx);
|
|
7654
|
+
dst_ptr[dst_idx] = *src_ptr;
|
|
7655
|
+
} else {
|
|
7656
|
+
dst_ptr[dst_idx] = 0;
|
|
7657
|
+
}
|
|
7637
7658
|
}
|
|
7638
7659
|
}
|
|
7639
7660
|
}
|
|
@@ -7641,16 +7662,20 @@ static void ggml_compute_forward_pad_f32(
|
|
|
7641
7662
|
}
|
|
7642
7663
|
}
|
|
7643
7664
|
|
|
7665
|
+
|
|
7644
7666
|
void ggml_compute_forward_pad(
|
|
7645
7667
|
const ggml_compute_params * params,
|
|
7646
7668
|
ggml_tensor * dst) {
|
|
7647
|
-
|
|
7648
7669
|
const ggml_tensor * src0 = dst->src[0];
|
|
7649
|
-
|
|
7670
|
+
const bool circular = (bool) ggml_get_op_params_i32(dst, 8);
|
|
7650
7671
|
switch (src0->type) {
|
|
7651
7672
|
case GGML_TYPE_F32:
|
|
7652
7673
|
{
|
|
7653
|
-
|
|
7674
|
+
if (circular) {
|
|
7675
|
+
ggml_compute_forward_pad_f32<true>(params, dst);
|
|
7676
|
+
} else {
|
|
7677
|
+
ggml_compute_forward_pad_f32<false>(params, dst);
|
|
7678
|
+
}
|
|
7654
7679
|
} break;
|
|
7655
7680
|
default:
|
|
7656
7681
|
{
|
|
@@ -7854,6 +7879,18 @@ void ggml_compute_forward_timestep_embedding(
|
|
|
7854
7879
|
|
|
7855
7880
|
// ggml_compute_forward_argsort
|
|
7856
7881
|
|
|
7882
|
+
template<enum ggml_sort_order order>
|
|
7883
|
+
struct cmp_argsort {
|
|
7884
|
+
const float * data;
|
|
7885
|
+
bool operator()(int32_t a, int32_t b) const {
|
|
7886
|
+
if constexpr (order == GGML_SORT_ORDER_ASC) {
|
|
7887
|
+
return data[a] < data[b];
|
|
7888
|
+
} else {
|
|
7889
|
+
return data[a] > data[b];
|
|
7890
|
+
}
|
|
7891
|
+
}
|
|
7892
|
+
};
|
|
7893
|
+
|
|
7857
7894
|
static void ggml_compute_forward_argsort_f32(
|
|
7858
7895
|
const ggml_compute_params * params,
|
|
7859
7896
|
ggml_tensor * dst) {
|
|
@@ -7872,23 +7909,25 @@ static void ggml_compute_forward_argsort_f32(
|
|
|
7872
7909
|
ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
|
|
7873
7910
|
|
|
7874
7911
|
for (int64_t i = ith; i < nr; i += nth) {
|
|
7875
|
-
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
|
|
7876
7912
|
const float * src_data = (float *)((char *) src0->data + i*nb01);
|
|
7877
7913
|
|
|
7914
|
+
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
|
|
7915
|
+
|
|
7878
7916
|
for (int64_t j = 0; j < ne0; j++) {
|
|
7879
7917
|
dst_data[j] = j;
|
|
7880
7918
|
}
|
|
7881
7919
|
|
|
7882
|
-
|
|
7883
|
-
|
|
7884
|
-
|
|
7885
|
-
|
|
7886
|
-
|
|
7887
|
-
|
|
7888
|
-
|
|
7889
|
-
|
|
7890
|
-
|
|
7891
|
-
|
|
7920
|
+
switch (order) {
|
|
7921
|
+
case GGML_SORT_ORDER_ASC:
|
|
7922
|
+
std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
|
|
7923
|
+
break;
|
|
7924
|
+
|
|
7925
|
+
case GGML_SORT_ORDER_DESC:
|
|
7926
|
+
std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
|
|
7927
|
+
break;
|
|
7928
|
+
|
|
7929
|
+
default:
|
|
7930
|
+
GGML_ABORT("invalid sort order");
|
|
7892
7931
|
}
|
|
7893
7932
|
}
|
|
7894
7933
|
}
|
|
@@ -7911,12 +7950,78 @@ void ggml_compute_forward_argsort(
|
|
|
7911
7950
|
}
|
|
7912
7951
|
}
|
|
7913
7952
|
|
|
7953
|
+
// ggml_compute_forward_top_k
|
|
7954
|
+
|
|
7955
|
+
struct cmp_top_k {
|
|
7956
|
+
const float * data;
|
|
7957
|
+
bool operator()(int32_t a, int32_t b) const {
|
|
7958
|
+
return data[a] > data[b];
|
|
7959
|
+
}
|
|
7960
|
+
};
|
|
7961
|
+
|
|
7962
|
+
static void ggml_compute_forward_top_k_f32(
|
|
7963
|
+
const ggml_compute_params * params,
|
|
7964
|
+
ggml_tensor * dst) {
|
|
7965
|
+
|
|
7966
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
7967
|
+
|
|
7968
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
|
7969
|
+
|
|
7970
|
+
GGML_ASSERT(nb0 == sizeof(float));
|
|
7971
|
+
|
|
7972
|
+
const int ith = params->ith;
|
|
7973
|
+
const int nth = params->nth;
|
|
7974
|
+
|
|
7975
|
+
const int64_t nr = ggml_nrows(src0);
|
|
7976
|
+
|
|
7977
|
+
const int top_k = ne0;
|
|
7978
|
+
|
|
7979
|
+
int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
|
7980
|
+
|
|
7981
|
+
for (int64_t i = ith; i < nr; i += nth) {
|
|
7982
|
+
const float * src_data = (float *)((char *) src0->data + i*nb01);
|
|
7983
|
+
|
|
7984
|
+
for (int64_t j = 0; j < ne00; j++) {
|
|
7985
|
+
tmp[j] = j;
|
|
7986
|
+
}
|
|
7987
|
+
|
|
7988
|
+
std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
|
|
7989
|
+
|
|
7990
|
+
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
|
|
7991
|
+
|
|
7992
|
+
std::copy(tmp, tmp + top_k, dst_data);
|
|
7993
|
+
|
|
7994
|
+
// emphasize that the order is not important
|
|
7995
|
+
if (top_k > 1) {
|
|
7996
|
+
std::swap(dst_data[0], dst_data[1]);
|
|
7997
|
+
}
|
|
7998
|
+
}
|
|
7999
|
+
}
|
|
8000
|
+
|
|
8001
|
+
void ggml_compute_forward_top_k(
|
|
8002
|
+
const ggml_compute_params * params,
|
|
8003
|
+
ggml_tensor * dst) {
|
|
8004
|
+
|
|
8005
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
8006
|
+
|
|
8007
|
+
switch (src0->type) {
|
|
8008
|
+
case GGML_TYPE_F32:
|
|
8009
|
+
{
|
|
8010
|
+
ggml_compute_forward_top_k_f32(params, dst);
|
|
8011
|
+
} break;
|
|
8012
|
+
default:
|
|
8013
|
+
{
|
|
8014
|
+
GGML_ABORT("fatal error");
|
|
8015
|
+
}
|
|
8016
|
+
}
|
|
8017
|
+
}
|
|
8018
|
+
|
|
7914
8019
|
// ggml_compute_forward_flash_attn_ext
|
|
7915
8020
|
|
|
7916
|
-
static void
|
|
8021
|
+
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|
7917
8022
|
const ggml_compute_params * params,
|
|
7918
|
-
ggml_tensor * dst
|
|
7919
|
-
|
|
8023
|
+
ggml_tensor * dst,
|
|
8024
|
+
int ir0, int ir1) {
|
|
7920
8025
|
const ggml_tensor * q = dst->src[0];
|
|
7921
8026
|
const ggml_tensor * k = dst->src[1];
|
|
7922
8027
|
const ggml_tensor * v = dst->src[2];
|
|
@@ -7932,9 +8037,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
7932
8037
|
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
7933
8038
|
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
7934
8039
|
|
|
7935
|
-
const int ith = params->ith;
|
|
7936
|
-
const int nth = params->nth;
|
|
7937
|
-
|
|
7938
8040
|
const int64_t DK = nek0;
|
|
7939
8041
|
const int64_t DV = nev0;
|
|
7940
8042
|
const int64_t N = neq1;
|
|
@@ -7968,16 +8070,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
7968
8070
|
|
|
7969
8071
|
// parallelize by q rows using ggml_vec_dot_f32
|
|
7970
8072
|
|
|
7971
|
-
// total rows in q
|
|
7972
|
-
const int nr = neq1*neq2*neq3;
|
|
7973
|
-
|
|
7974
|
-
// rows per thread
|
|
7975
|
-
const int dr = (nr + nth - 1)/nth;
|
|
7976
|
-
|
|
7977
|
-
// row range for this thread
|
|
7978
|
-
const int ir0 = dr*ith;
|
|
7979
|
-
const int ir1 = MIN(ir0 + dr, nr);
|
|
7980
|
-
|
|
7981
8073
|
float scale = 1.0f;
|
|
7982
8074
|
float max_bias = 0.0f;
|
|
7983
8075
|
float logit_softcap = 0.0f;
|
|
@@ -8004,6 +8096,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
8004
8096
|
GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
|
|
8005
8097
|
GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
|
|
8006
8098
|
|
|
8099
|
+
int ith = params->ith;
|
|
8100
|
+
|
|
8007
8101
|
// loop over n_batch and n_head
|
|
8008
8102
|
for (int ir = ir0; ir < ir1; ++ir) {
|
|
8009
8103
|
// q indices
|
|
@@ -8135,7 +8229,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
8135
8229
|
}
|
|
8136
8230
|
|
|
8137
8231
|
// V /= S
|
|
8138
|
-
const float S_inv = 1.0f/S;
|
|
8232
|
+
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
|
|
8139
8233
|
ggml_vec_scale_f32(DV, VKQ32, S_inv);
|
|
8140
8234
|
|
|
8141
8235
|
// dst indices
|
|
@@ -8151,6 +8245,91 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
8151
8245
|
}
|
|
8152
8246
|
}
|
|
8153
8247
|
|
|
8248
|
+
static void ggml_compute_forward_flash_attn_ext_f16(
|
|
8249
|
+
const ggml_compute_params * params,
|
|
8250
|
+
ggml_tensor * dst) {
|
|
8251
|
+
|
|
8252
|
+
const ggml_tensor * q = dst->src[0];
|
|
8253
|
+
const ggml_tensor * k = dst->src[1];
|
|
8254
|
+
const ggml_tensor * v = dst->src[2];
|
|
8255
|
+
|
|
8256
|
+
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
|
8257
|
+
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
|
8258
|
+
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
|
8259
|
+
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
|
8260
|
+
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
|
8261
|
+
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
|
8262
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
8263
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
8264
|
+
|
|
8265
|
+
const int64_t DK = nek0;
|
|
8266
|
+
const int64_t DV = nev0;
|
|
8267
|
+
const int64_t N = neq1;
|
|
8268
|
+
|
|
8269
|
+
GGML_ASSERT(ne0 == DV);
|
|
8270
|
+
GGML_ASSERT(ne2 == N);
|
|
8271
|
+
|
|
8272
|
+
// input tensor rows must be contiguous
|
|
8273
|
+
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
|
|
8274
|
+
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
|
|
8275
|
+
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
|
|
8276
|
+
|
|
8277
|
+
GGML_ASSERT(neq0 == DK);
|
|
8278
|
+
GGML_ASSERT(nek0 == DK);
|
|
8279
|
+
GGML_ASSERT(nev0 == DV);
|
|
8280
|
+
|
|
8281
|
+
GGML_ASSERT(neq1 == N);
|
|
8282
|
+
|
|
8283
|
+
// dst cannot be transposed or permuted
|
|
8284
|
+
GGML_ASSERT(nb0 == sizeof(float));
|
|
8285
|
+
GGML_ASSERT(nb0 <= nb1);
|
|
8286
|
+
GGML_ASSERT(nb1 <= nb2);
|
|
8287
|
+
GGML_ASSERT(nb2 <= nb3);
|
|
8288
|
+
|
|
8289
|
+
// parallelize by q rows using ggml_vec_dot_f32
|
|
8290
|
+
|
|
8291
|
+
// total rows in q
|
|
8292
|
+
const int64_t nr = neq1*neq2*neq3;
|
|
8293
|
+
|
|
8294
|
+
// rows per thread
|
|
8295
|
+
const int ith = params->ith;
|
|
8296
|
+
const int nth = params->nth;
|
|
8297
|
+
|
|
8298
|
+
// disable for NUMA
|
|
8299
|
+
const bool disable_chunking = ggml_is_numa();
|
|
8300
|
+
|
|
8301
|
+
// 4x chunks per thread
|
|
8302
|
+
int nth_scaled = nth * 4;
|
|
8303
|
+
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
|
|
8304
|
+
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
|
|
8305
|
+
|
|
8306
|
+
if (nth == 1 || nchunk < nth || disable_chunking) {
|
|
8307
|
+
nchunk = nth;
|
|
8308
|
+
}
|
|
8309
|
+
|
|
8310
|
+
if (ith == 0) {
|
|
8311
|
+
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
|
|
8312
|
+
ggml_threadpool_chunk_set(params->threadpool, nth);
|
|
8313
|
+
}
|
|
8314
|
+
|
|
8315
|
+
ggml_barrier(params->threadpool);
|
|
8316
|
+
|
|
8317
|
+
// The number of elements in each chunk
|
|
8318
|
+
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
|
8319
|
+
|
|
8320
|
+
// The first chunk comes from our thread_id, the rest will get auto-assigned.
|
|
8321
|
+
int current_chunk = ith;
|
|
8322
|
+
|
|
8323
|
+
while (current_chunk < nchunk) {
|
|
8324
|
+
const int64_t ir0 = dr * current_chunk;
|
|
8325
|
+
const int64_t ir1 = MIN(ir0 + dr, nr);
|
|
8326
|
+
|
|
8327
|
+
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
|
|
8328
|
+
|
|
8329
|
+
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
|
8330
|
+
}
|
|
8331
|
+
}
|
|
8332
|
+
|
|
8154
8333
|
void ggml_compute_forward_flash_attn_ext(
|
|
8155
8334
|
const ggml_compute_params * params,
|
|
8156
8335
|
ggml_tensor * dst) {
|
|
@@ -8637,7 +8816,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|
|
8637
8816
|
// n_head
|
|
8638
8817
|
for (int h = ih0; h < ih1; ++h) {
|
|
8639
8818
|
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
|
8640
|
-
const float dt_soft_plus =
|
|
8819
|
+
const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
|
|
8641
8820
|
const float dA = expf(dt_soft_plus * A[h]);
|
|
8642
8821
|
const int g = h / (nh / ng); // repeat_interleave
|
|
8643
8822
|
|
|
@@ -8734,7 +8913,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|
|
8734
8913
|
// n_head
|
|
8735
8914
|
for (int h = ih0; h < ih1; ++h) {
|
|
8736
8915
|
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
|
8737
|
-
const float dt_soft_plus =
|
|
8916
|
+
const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
|
|
8738
8917
|
const int g = h / (nh / ng); // repeat_interleave
|
|
8739
8918
|
|
|
8740
8919
|
// dim
|
|
@@ -8997,6 +9176,34 @@ void ggml_compute_forward_unary(
|
|
|
8997
9176
|
{
|
|
8998
9177
|
ggml_compute_forward_exp(params, dst);
|
|
8999
9178
|
} break;
|
|
9179
|
+
case GGML_UNARY_OP_FLOOR:
|
|
9180
|
+
{
|
|
9181
|
+
ggml_compute_forward_floor(params, dst);
|
|
9182
|
+
} break;
|
|
9183
|
+
case GGML_UNARY_OP_CEIL:
|
|
9184
|
+
{
|
|
9185
|
+
ggml_compute_forward_ceil(params, dst);
|
|
9186
|
+
} break;
|
|
9187
|
+
case GGML_UNARY_OP_ROUND:
|
|
9188
|
+
{
|
|
9189
|
+
ggml_compute_forward_round(params, dst);
|
|
9190
|
+
} break;
|
|
9191
|
+
case GGML_UNARY_OP_TRUNC:
|
|
9192
|
+
{
|
|
9193
|
+
ggml_compute_forward_trunc(params, dst);
|
|
9194
|
+
} break;
|
|
9195
|
+
case GGML_UNARY_OP_XIELU:
|
|
9196
|
+
{
|
|
9197
|
+
ggml_compute_forward_xielu(params, dst);
|
|
9198
|
+
} break;
|
|
9199
|
+
case GGML_UNARY_OP_EXPM1:
|
|
9200
|
+
{
|
|
9201
|
+
ggml_compute_forward_expm1(params, dst);
|
|
9202
|
+
} break;
|
|
9203
|
+
case GGML_UNARY_OP_SOFTPLUS:
|
|
9204
|
+
{
|
|
9205
|
+
ggml_compute_forward_softplus(params, dst);
|
|
9206
|
+
} break;
|
|
9000
9207
|
default:
|
|
9001
9208
|
{
|
|
9002
9209
|
GGML_ABORT("fatal error");
|
|
@@ -9593,6 +9800,76 @@ void ggml_compute_forward_gla(
|
|
|
9593
9800
|
}
|
|
9594
9801
|
}
|
|
9595
9802
|
|
|
9803
|
+
// Solve A·X = B for X by forward substitution, where A (src0) is a lower
// triangular n×n matrix and B (src1) holds k right-hand-side columns.
// Batched over the two outer dimensions; work is split across threads at the
// granularity of one RHS column of one batch (each column solve is independent).
// NOTE(review): batch offsets honor nb02/nb03 (and nb12/nb13, nb2/nb3), but the
// inner 2-D matrices are indexed as if contiguous (row stride n resp. k) —
// assumes callers only submit tensors with contiguous inner matrices; TODO confirm.
static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0]; // A (lower triangular)
    const struct ggml_tensor * src1 = dst->src[1]; // B (RHS)

    // brings ne0x/nb0x (src0), ne1x/nb1x (src1), ne/nb (dst) into scope
    GGML_TENSOR_BINARY_OP_LOCALS;

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    GGML_ASSERT(ne00 == ne01); // A must be square
    GGML_ASSERT(ne0 == ne10); // solution cols == B cols
    GGML_ASSERT(ne1 == ne11); // solution rows == B rows

    // batch dimensions of A, B and X must agree
    GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
    GGML_ASSERT(ne03 == ne13 && ne13 == ne3);

    const int ith = params->ith;
    const int nth = params->nth;

    const int64_t k = ne10; // number of RHS columns
    const int64_t n = ne11; // A is n×n
    const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit

    // chunks per thread
    const int64_t dr = (nr + nth - 1)/nth;

    // chunk range for this thread
    const int64_t ir0 = dr*ith;
    const int64_t ir1 = MIN(ir0 + dr, nr);

    const float * A = (const float *) src0->data; // [n, n, B1, B2]
    const float * B = (const float *) src1->data; // [n, k, B1, B2]
    float * X = ( float *) dst->data; // [n, k, B1, B2]

    for (int64_t ir = ir0; ir < ir1; ++ir) {
        // decompose the flat work index into (outer batch i03, inner batch i02, column i01)
        const int64_t i03 = ir/(ne02*k);
        const int64_t i02 = (ir - i03*ne02*k)/k;
        const int64_t i01 = (ir - i03*ne02*k - i02*k);

        // nbXX are byte strides; convert to float-element offsets
        const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
        const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);

        float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);

        // forward substitution down column i01:
        //   x[i] = (b[i] - sum_{t<i} A[i][t] * x[t]) / A[i][i]
        // reads of X_batch stay within this column, so no cross-thread races
        for (int64_t i00 = 0; i00 < n; ++i00) {
            float sum = 0.0f;
            for (int64_t t = 0; t < i00; ++t) {
                sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
            }

            const float diag = A_batch[i00 * n + i00];
            // a singular (zero-diagonal) triangular matrix has no unique solution
            assert(diag != 0.0f && "Zero diagonal in triangular matrix");

            X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
        }
    }
}
|
9861
|
+
|
|
9862
|
+
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
|
9863
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
9864
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
9865
|
+
|
|
9866
|
+
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
|
9867
|
+
ggml_compute_forward_solve_tri_f32(params, dst);
|
|
9868
|
+
} else {
|
|
9869
|
+
GGML_ABORT("fatal error");
|
|
9870
|
+
}
|
|
9871
|
+
}
|
|
9872
|
+
|
|
9596
9873
|
// ggml_compute_forward_rwkv_wkv7
|
|
9597
9874
|
|
|
9598
9875
|
static void ggml_compute_forward_rwkv_wkv7_f32(
|