whispercpp 1.3.3 → 1.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +79 -25
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/addon.cpp +19 -19
- data/ext/sources/examples/addon.node/index.js +7 -5
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +26 -16
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +122 -111
- data/ext/sources/examples/command/command.cpp +26 -24
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/lsp/lsp.cpp +19 -17
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +34 -24
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +4 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
- data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
- data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
- data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
- data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
- data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
- data/ext/sources/examples/talk-llama/llama-context.h +99 -36
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
- data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
- data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
- data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
- data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
- data/ext/sources/examples/talk-llama/llama-model.h +104 -12
- data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
- data/ext/sources/examples/talk-llama/llama.cpp +794 -12
- data/ext/sources/examples/talk-llama/llama.h +246 -190
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
- data/ext/sources/ggml/CMakeLists.txt +135 -79
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +21 -2
- data/ext/sources/ggml/include/ggml-cpu.h +2 -1
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +406 -23
- data/ext/sources/ggml/src/CMakeLists.txt +99 -13
- data/ext/sources/ggml/src/ggml-alloc.c +368 -161
- data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
- data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
- data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
- data/ext/sources/ggml/src/ggml-common.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
- data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
- data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
- data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
- data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
- data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
- data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
- data/ext/sources/ggml/src/ggml-impl.h +186 -15
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +111 -16
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +901 -129
- data/ext/sources/ggml/src/gguf.cpp +8 -1
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +124 -81
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +7 -5
- data/ext/sources/tests/test-vad.cpp +3 -3
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +126 -2
- data/test/test_params.rb +24 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +8 -1
- data/whispercpp.gemspec +1 -1
- metadata +439 -179
- data/ext/sources/build-xcframework.sh +0 -547
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
|
@@ -7,7 +7,10 @@
|
|
|
7
7
|
#include "unary-ops.h"
|
|
8
8
|
#include "vec.h"
|
|
9
9
|
|
|
10
|
-
#include <
|
|
10
|
+
#include <cfloat>
|
|
11
|
+
#include <algorithm>
|
|
12
|
+
#include <cmath>
|
|
13
|
+
#include <functional>
|
|
11
14
|
|
|
12
15
|
// ggml_compute_forward_dup
|
|
13
16
|
|
|
@@ -40,13 +43,15 @@ static void ggml_compute_forward_dup_same_cont(
|
|
|
40
43
|
}
|
|
41
44
|
}
|
|
42
45
|
|
|
43
|
-
|
|
46
|
+
template<typename src_t, typename dst_t>
|
|
47
|
+
static void ggml_compute_forward_dup_flt(
|
|
44
48
|
const ggml_compute_params * params,
|
|
45
49
|
ggml_tensor * dst) {
|
|
46
50
|
|
|
47
51
|
const ggml_tensor * src0 = dst->src[0];
|
|
48
52
|
|
|
49
53
|
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
|
54
|
+
GGML_ASSERT(!ggml_is_quantized(src0->type) && !ggml_is_quantized(dst->type));
|
|
50
55
|
|
|
51
56
|
GGML_TENSOR_UNARY_OP_LOCALS
|
|
52
57
|
|
|
@@ -61,6 +66,7 @@ static void ggml_compute_forward_dup_f16(
|
|
|
61
66
|
const int ir0 = dr * ith;
|
|
62
67
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
63
68
|
|
|
69
|
+
// case: type & row size equal
|
|
64
70
|
if (src0->type == dst->type &&
|
|
65
71
|
ne00 == ne0 &&
|
|
66
72
|
nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
|
|
@@ -79,275 +85,11 @@ static void ggml_compute_forward_dup_f16(
|
|
|
79
85
|
return;
|
|
80
86
|
}
|
|
81
87
|
|
|
82
|
-
//
|
|
83
|
-
|
|
84
|
-
if (ggml_is_contiguous(dst)) {
|
|
85
|
-
if (nb00 == sizeof(ggml_fp16_t)) {
|
|
86
|
-
if (dst->type == GGML_TYPE_F16) {
|
|
87
|
-
size_t id = 0;
|
|
88
|
-
const size_t rs = ne00 * nb00;
|
|
89
|
-
char * dst_ptr = (char *) dst->data;
|
|
90
|
-
|
|
91
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
92
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
93
|
-
id += rs * ir0;
|
|
94
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
95
|
-
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
|
96
|
-
memcpy(dst_ptr + id, src0_ptr, rs);
|
|
97
|
-
id += rs;
|
|
98
|
-
}
|
|
99
|
-
id += rs * (ne01 - ir1);
|
|
100
|
-
}
|
|
101
|
-
}
|
|
102
|
-
} else if (dst->type == GGML_TYPE_F32) {
|
|
103
|
-
size_t id = 0;
|
|
104
|
-
float * dst_ptr = (float *) dst->data;
|
|
105
|
-
|
|
106
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
107
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
108
|
-
id += ne00 * ir0;
|
|
109
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
110
|
-
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
111
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
112
|
-
dst_ptr[id] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
|
|
113
|
-
id++;
|
|
114
|
-
}
|
|
115
|
-
}
|
|
116
|
-
id += ne00 * (ne01 - ir1);
|
|
117
|
-
}
|
|
118
|
-
}
|
|
119
|
-
} else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
|
|
120
|
-
ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
|
|
121
|
-
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
|
122
|
-
|
|
123
|
-
size_t id = 0;
|
|
124
|
-
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
|
|
125
|
-
char * dst_ptr = (char *) dst->data;
|
|
126
|
-
|
|
127
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
128
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
129
|
-
id += rs * ir0;
|
|
130
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
131
|
-
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
132
|
-
|
|
133
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
134
|
-
src0_f32[i00] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
|
|
135
|
-
}
|
|
136
|
-
|
|
137
|
-
quantize_row_q(src0_f32, dst_ptr + id, ne00);
|
|
138
|
-
id += rs;
|
|
139
|
-
}
|
|
140
|
-
id += rs * (ne01 - ir1);
|
|
141
|
-
}
|
|
142
|
-
}
|
|
143
|
-
} else {
|
|
144
|
-
GGML_ABORT("fatal error"); // TODO: implement
|
|
145
|
-
}
|
|
146
|
-
} else {
|
|
147
|
-
//printf("%s: this is not optimal - fix me\n", __func__);
|
|
148
|
-
|
|
149
|
-
if (dst->type == GGML_TYPE_F32) {
|
|
150
|
-
size_t id = 0;
|
|
151
|
-
float * dst_ptr = (float *) dst->data;
|
|
152
|
-
|
|
153
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
154
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
155
|
-
id += ne00 * ir0;
|
|
156
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
157
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
158
|
-
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
159
|
-
|
|
160
|
-
dst_ptr[id] = GGML_CPU_FP16_TO_FP32(*src0_ptr);
|
|
161
|
-
id++;
|
|
162
|
-
}
|
|
163
|
-
}
|
|
164
|
-
id += ne00 * (ne01 - ir1);
|
|
165
|
-
}
|
|
166
|
-
}
|
|
167
|
-
} else if (dst->type == GGML_TYPE_F16) {
|
|
168
|
-
size_t id = 0;
|
|
169
|
-
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
|
170
|
-
|
|
171
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
172
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
173
|
-
id += ne00 * ir0;
|
|
174
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
175
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
176
|
-
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
177
|
-
|
|
178
|
-
dst_ptr[id] = *src0_ptr;
|
|
179
|
-
id++;
|
|
180
|
-
}
|
|
181
|
-
}
|
|
182
|
-
id += ne00 * (ne01 - ir1);
|
|
183
|
-
}
|
|
184
|
-
}
|
|
185
|
-
} else {
|
|
186
|
-
GGML_ABORT("fatal error"); // TODO: implement
|
|
187
|
-
}
|
|
188
|
-
}
|
|
189
|
-
return;
|
|
190
|
-
}
|
|
191
|
-
|
|
192
|
-
// dst counters
|
|
193
|
-
int64_t i10 = 0;
|
|
194
|
-
int64_t i11 = 0;
|
|
195
|
-
int64_t i12 = 0;
|
|
196
|
-
int64_t i13 = 0;
|
|
197
|
-
|
|
198
|
-
if (dst->type == GGML_TYPE_F16) {
|
|
199
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
200
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
201
|
-
i10 += ne00 * ir0;
|
|
202
|
-
while (i10 >= ne0) {
|
|
203
|
-
i10 -= ne0;
|
|
204
|
-
if (++i11 == ne1) {
|
|
205
|
-
i11 = 0;
|
|
206
|
-
if (++i12 == ne2) {
|
|
207
|
-
i12 = 0;
|
|
208
|
-
if (++i13 == ne3) {
|
|
209
|
-
i13 = 0;
|
|
210
|
-
}
|
|
211
|
-
}
|
|
212
|
-
}
|
|
213
|
-
}
|
|
214
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
215
|
-
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
216
|
-
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
217
|
-
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
218
|
-
|
|
219
|
-
memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
|
|
220
|
-
|
|
221
|
-
if (++i10 == ne00) {
|
|
222
|
-
i10 = 0;
|
|
223
|
-
if (++i11 == ne01) {
|
|
224
|
-
i11 = 0;
|
|
225
|
-
if (++i12 == ne02) {
|
|
226
|
-
i12 = 0;
|
|
227
|
-
if (++i13 == ne03) {
|
|
228
|
-
i13 = 0;
|
|
229
|
-
}
|
|
230
|
-
}
|
|
231
|
-
}
|
|
232
|
-
}
|
|
233
|
-
}
|
|
234
|
-
}
|
|
235
|
-
i10 += ne00 * (ne01 - ir1);
|
|
236
|
-
while (i10 >= ne0) {
|
|
237
|
-
i10 -= ne0;
|
|
238
|
-
if (++i11 == ne1) {
|
|
239
|
-
i11 = 0;
|
|
240
|
-
if (++i12 == ne2) {
|
|
241
|
-
i12 = 0;
|
|
242
|
-
if (++i13 == ne3) {
|
|
243
|
-
i13 = 0;
|
|
244
|
-
}
|
|
245
|
-
}
|
|
246
|
-
}
|
|
247
|
-
}
|
|
248
|
-
}
|
|
249
|
-
}
|
|
250
|
-
} else if (dst->type == GGML_TYPE_F32) {
|
|
251
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
252
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
253
|
-
i10 += ne00 * ir0;
|
|
254
|
-
while (i10 >= ne0) {
|
|
255
|
-
i10 -= ne0;
|
|
256
|
-
if (++i11 == ne1) {
|
|
257
|
-
i11 = 0;
|
|
258
|
-
if (++i12 == ne2) {
|
|
259
|
-
i12 = 0;
|
|
260
|
-
if (++i13 == ne3) {
|
|
261
|
-
i13 = 0;
|
|
262
|
-
}
|
|
263
|
-
}
|
|
264
|
-
}
|
|
265
|
-
}
|
|
266
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
267
|
-
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
268
|
-
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
269
|
-
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
270
|
-
|
|
271
|
-
*(float *) dst_ptr = GGML_CPU_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
|
|
272
|
-
|
|
273
|
-
if (++i10 == ne0) {
|
|
274
|
-
i10 = 0;
|
|
275
|
-
if (++i11 == ne1) {
|
|
276
|
-
i11 = 0;
|
|
277
|
-
if (++i12 == ne2) {
|
|
278
|
-
i12 = 0;
|
|
279
|
-
if (++i13 == ne3) {
|
|
280
|
-
i13 = 0;
|
|
281
|
-
}
|
|
282
|
-
}
|
|
283
|
-
}
|
|
284
|
-
}
|
|
285
|
-
}
|
|
286
|
-
}
|
|
287
|
-
i10 += ne00 * (ne01 - ir1);
|
|
288
|
-
while (i10 >= ne0) {
|
|
289
|
-
i10 -= ne0;
|
|
290
|
-
if (++i11 == ne1) {
|
|
291
|
-
i11 = 0;
|
|
292
|
-
if (++i12 == ne2) {
|
|
293
|
-
i12 = 0;
|
|
294
|
-
if (++i13 == ne3) {
|
|
295
|
-
i13 = 0;
|
|
296
|
-
}
|
|
297
|
-
}
|
|
298
|
-
}
|
|
299
|
-
}
|
|
300
|
-
}
|
|
301
|
-
}
|
|
302
|
-
} else {
|
|
303
|
-
GGML_ABORT("fatal error"); // TODO: implement
|
|
304
|
-
}
|
|
305
|
-
}
|
|
306
|
-
|
|
307
|
-
static void ggml_compute_forward_dup_bf16(
|
|
308
|
-
const ggml_compute_params * params,
|
|
309
|
-
ggml_tensor * dst) {
|
|
310
|
-
|
|
311
|
-
const ggml_tensor * src0 = dst->src[0];
|
|
312
|
-
|
|
313
|
-
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
|
314
|
-
|
|
315
|
-
GGML_TENSOR_UNARY_OP_LOCALS
|
|
316
|
-
|
|
317
|
-
const int ith = params->ith; // thread index
|
|
318
|
-
const int nth = params->nth; // number of threads
|
|
319
|
-
|
|
320
|
-
// parallelize by rows
|
|
321
|
-
const int nr = ne01;
|
|
322
|
-
// number of rows per thread
|
|
323
|
-
const int dr = (nr + nth - 1) / nth;
|
|
324
|
-
// row range for this thread
|
|
325
|
-
const int ir0 = dr * ith;
|
|
326
|
-
const int ir1 = MIN(ir0 + dr, nr);
|
|
327
|
-
|
|
328
|
-
if (src0->type == dst->type &&
|
|
329
|
-
ne00 == ne0 &&
|
|
330
|
-
nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
|
|
331
|
-
// copy by rows
|
|
332
|
-
const size_t rs = ne00*nb00;
|
|
333
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
334
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
335
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
336
|
-
memcpy(
|
|
337
|
-
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
|
|
338
|
-
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
|
|
339
|
-
rs);
|
|
340
|
-
}
|
|
341
|
-
}
|
|
342
|
-
}
|
|
343
|
-
return;
|
|
344
|
-
}
|
|
345
|
-
|
|
346
|
-
// TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
|
|
347
|
-
|
|
88
|
+
// case: dst tensor is contiguous
|
|
348
89
|
if (ggml_is_contiguous(dst)) {
|
|
349
|
-
if (nb00 == sizeof(
|
|
350
|
-
if (
|
|
90
|
+
if (nb00 == sizeof(src_t)) {
|
|
91
|
+
if constexpr (std::is_same_v<dst_t, src_t>) {
|
|
92
|
+
// same type
|
|
351
93
|
size_t id = 0;
|
|
352
94
|
const size_t rs = ne00 * nb00;
|
|
353
95
|
char * dst_ptr = (char *) dst->data;
|
|
@@ -363,434 +105,58 @@ static void ggml_compute_forward_dup_bf16(
|
|
|
363
105
|
id += rs * (ne01 - ir1);
|
|
364
106
|
}
|
|
365
107
|
}
|
|
366
|
-
} else
|
|
367
|
-
|
|
368
|
-
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
|
369
|
-
|
|
370
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
371
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
372
|
-
id += ne00 * ir0;
|
|
373
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
374
|
-
const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
375
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
376
|
-
dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
|
|
377
|
-
id++;
|
|
378
|
-
}
|
|
379
|
-
}
|
|
380
|
-
id += ne00 * (ne01 - ir1);
|
|
381
|
-
}
|
|
382
|
-
}
|
|
383
|
-
} else if (dst->type == GGML_TYPE_F32) {
|
|
384
|
-
size_t id = 0;
|
|
385
|
-
float * dst_ptr = (float *) dst->data;
|
|
386
|
-
|
|
387
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
388
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
389
|
-
id += ne00 * ir0;
|
|
390
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
391
|
-
const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
392
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
393
|
-
dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
|
|
394
|
-
id++;
|
|
395
|
-
}
|
|
396
|
-
}
|
|
397
|
-
id += ne00 * (ne01 - ir1);
|
|
398
|
-
}
|
|
399
|
-
}
|
|
400
|
-
} else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
|
|
401
|
-
ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
|
|
402
|
-
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
|
403
|
-
|
|
404
|
-
size_t id = 0;
|
|
405
|
-
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
|
|
406
|
-
char * dst_ptr = (char *) dst->data;
|
|
407
|
-
|
|
408
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
409
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
410
|
-
id += rs * ir0;
|
|
411
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
412
|
-
const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
413
|
-
|
|
414
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
415
|
-
src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
|
|
416
|
-
}
|
|
417
|
-
|
|
418
|
-
quantize_row_q(src0_f32, dst_ptr + id, ne00);
|
|
419
|
-
id += rs;
|
|
420
|
-
}
|
|
421
|
-
id += rs * (ne01 - ir1);
|
|
422
|
-
}
|
|
423
|
-
}
|
|
424
|
-
} else {
|
|
425
|
-
GGML_ABORT("fatal error"); // TODO: implement
|
|
426
|
-
}
|
|
427
|
-
} else {
|
|
428
|
-
//printf("%s: this is not optimal - fix me\n", __func__);
|
|
429
|
-
|
|
430
|
-
if (dst->type == GGML_TYPE_F32) {
|
|
431
|
-
size_t id = 0;
|
|
432
|
-
float * dst_ptr = (float *) dst->data;
|
|
433
|
-
|
|
434
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
435
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
436
|
-
id += ne00 * ir0;
|
|
437
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
438
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
439
|
-
const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
440
|
-
|
|
441
|
-
dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
|
|
442
|
-
id++;
|
|
443
|
-
}
|
|
444
|
-
}
|
|
445
|
-
id += ne00 * (ne01 - ir1);
|
|
446
|
-
}
|
|
447
|
-
}
|
|
448
|
-
} else if (dst->type == GGML_TYPE_BF16) {
|
|
449
|
-
size_t id = 0;
|
|
450
|
-
ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
|
|
451
|
-
|
|
452
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
453
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
454
|
-
id += ne00 * ir0;
|
|
455
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
456
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
457
|
-
const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
458
|
-
|
|
459
|
-
dst_ptr[id] = *src0_ptr;
|
|
460
|
-
id++;
|
|
461
|
-
}
|
|
462
|
-
}
|
|
463
|
-
id += ne00 * (ne01 - ir1);
|
|
464
|
-
}
|
|
465
|
-
}
|
|
466
|
-
} else if (dst->type == GGML_TYPE_F16) {
|
|
467
|
-
size_t id = 0;
|
|
468
|
-
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
|
469
|
-
|
|
470
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
471
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
472
|
-
id += ne00 * ir0;
|
|
473
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
474
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
475
|
-
const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
476
|
-
|
|
477
|
-
dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
|
|
478
|
-
id++;
|
|
479
|
-
}
|
|
480
|
-
}
|
|
481
|
-
id += ne00 * (ne01 - ir1);
|
|
482
|
-
}
|
|
483
|
-
}
|
|
484
|
-
} else {
|
|
485
|
-
GGML_ABORT("fatal error"); // TODO: implement
|
|
486
|
-
}
|
|
487
|
-
}
|
|
488
|
-
return;
|
|
489
|
-
}
|
|
490
|
-
|
|
491
|
-
// dst counters
|
|
492
|
-
int64_t i10 = 0;
|
|
493
|
-
int64_t i11 = 0;
|
|
494
|
-
int64_t i12 = 0;
|
|
495
|
-
int64_t i13 = 0;
|
|
496
|
-
|
|
497
|
-
if (dst->type == GGML_TYPE_BF16) {
|
|
498
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
499
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
500
|
-
i10 += ne00 * ir0;
|
|
501
|
-
while (i10 >= ne0) {
|
|
502
|
-
i10 -= ne0;
|
|
503
|
-
if (++i11 == ne1) {
|
|
504
|
-
i11 = 0;
|
|
505
|
-
if (++i12 == ne2) {
|
|
506
|
-
i12 = 0;
|
|
507
|
-
if (++i13 == ne3) {
|
|
508
|
-
i13 = 0;
|
|
509
|
-
}
|
|
510
|
-
}
|
|
511
|
-
}
|
|
512
|
-
}
|
|
513
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
514
|
-
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
515
|
-
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
516
|
-
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
517
|
-
|
|
518
|
-
memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
|
|
519
|
-
|
|
520
|
-
if (++i10 == ne00) {
|
|
521
|
-
i10 = 0;
|
|
522
|
-
if (++i11 == ne01) {
|
|
523
|
-
i11 = 0;
|
|
524
|
-
if (++i12 == ne02) {
|
|
525
|
-
i12 = 0;
|
|
526
|
-
if (++i13 == ne03) {
|
|
527
|
-
i13 = 0;
|
|
528
|
-
}
|
|
529
|
-
}
|
|
530
|
-
}
|
|
531
|
-
}
|
|
532
|
-
}
|
|
533
|
-
}
|
|
534
|
-
i10 += ne00 * (ne01 - ir1);
|
|
535
|
-
while (i10 >= ne0) {
|
|
536
|
-
i10 -= ne0;
|
|
537
|
-
if (++i11 == ne1) {
|
|
538
|
-
i11 = 0;
|
|
539
|
-
if (++i12 == ne2) {
|
|
540
|
-
i12 = 0;
|
|
541
|
-
if (++i13 == ne3) {
|
|
542
|
-
i13 = 0;
|
|
543
|
-
}
|
|
544
|
-
}
|
|
545
|
-
}
|
|
546
|
-
}
|
|
547
|
-
}
|
|
548
|
-
}
|
|
549
|
-
} else if (dst->type == GGML_TYPE_F16) {
|
|
550
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
551
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
552
|
-
i10 += ne00 * ir0;
|
|
553
|
-
while (i10 >= ne0) {
|
|
554
|
-
i10 -= ne0;
|
|
555
|
-
if (++i11 == ne1) {
|
|
556
|
-
i11 = 0;
|
|
557
|
-
if (++i12 == ne2) {
|
|
558
|
-
i12 = 0;
|
|
559
|
-
if (++i13 == ne3) {
|
|
560
|
-
i13 = 0;
|
|
561
|
-
}
|
|
562
|
-
}
|
|
563
|
-
}
|
|
564
|
-
}
|
|
565
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
566
|
-
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
567
|
-
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
568
|
-
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
569
|
-
|
|
570
|
-
*(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
|
|
571
|
-
|
|
572
|
-
if (++i10 == ne0) {
|
|
573
|
-
i10 = 0;
|
|
574
|
-
if (++i11 == ne1) {
|
|
575
|
-
i11 = 0;
|
|
576
|
-
if (++i12 == ne2) {
|
|
577
|
-
i12 = 0;
|
|
578
|
-
if (++i13 == ne3) {
|
|
579
|
-
i13 = 0;
|
|
580
|
-
}
|
|
581
|
-
}
|
|
582
|
-
}
|
|
583
|
-
}
|
|
584
|
-
}
|
|
585
|
-
}
|
|
586
|
-
i10 += ne00 * (ne01 - ir1);
|
|
587
|
-
while (i10 >= ne0) {
|
|
588
|
-
i10 -= ne0;
|
|
589
|
-
if (++i11 == ne1) {
|
|
590
|
-
i11 = 0;
|
|
591
|
-
if (++i12 == ne2) {
|
|
592
|
-
i12 = 0;
|
|
593
|
-
if (++i13 == ne3) {
|
|
594
|
-
i13 = 0;
|
|
595
|
-
}
|
|
596
|
-
}
|
|
597
|
-
}
|
|
598
|
-
}
|
|
599
|
-
}
|
|
600
|
-
}
|
|
601
|
-
} else if (dst->type == GGML_TYPE_F32) {
|
|
602
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
603
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
604
|
-
i10 += ne00 * ir0;
|
|
605
|
-
while (i10 >= ne0) {
|
|
606
|
-
i10 -= ne0;
|
|
607
|
-
if (++i11 == ne1) {
|
|
608
|
-
i11 = 0;
|
|
609
|
-
if (++i12 == ne2) {
|
|
610
|
-
i12 = 0;
|
|
611
|
-
if (++i13 == ne3) {
|
|
612
|
-
i13 = 0;
|
|
613
|
-
}
|
|
614
|
-
}
|
|
615
|
-
}
|
|
616
|
-
}
|
|
617
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
618
|
-
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
619
|
-
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
620
|
-
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
621
|
-
|
|
622
|
-
*(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
|
|
623
|
-
|
|
624
|
-
if (++i10 == ne0) {
|
|
625
|
-
i10 = 0;
|
|
626
|
-
if (++i11 == ne1) {
|
|
627
|
-
i11 = 0;
|
|
628
|
-
if (++i12 == ne2) {
|
|
629
|
-
i12 = 0;
|
|
630
|
-
if (++i13 == ne3) {
|
|
631
|
-
i13 = 0;
|
|
632
|
-
}
|
|
633
|
-
}
|
|
634
|
-
}
|
|
635
|
-
}
|
|
636
|
-
}
|
|
637
|
-
}
|
|
638
|
-
i10 += ne00 * (ne01 - ir1);
|
|
639
|
-
while (i10 >= ne0) {
|
|
640
|
-
i10 -= ne0;
|
|
641
|
-
if (++i11 == ne1) {
|
|
642
|
-
i11 = 0;
|
|
643
|
-
if (++i12 == ne2) {
|
|
644
|
-
i12 = 0;
|
|
645
|
-
if (++i13 == ne3) {
|
|
646
|
-
i13 = 0;
|
|
647
|
-
}
|
|
648
|
-
}
|
|
649
|
-
}
|
|
650
|
-
}
|
|
651
|
-
}
|
|
652
|
-
}
|
|
653
|
-
} else {
|
|
654
|
-
GGML_ABORT("fatal error"); // TODO: implement
|
|
655
|
-
}
|
|
656
|
-
}
|
|
657
|
-
|
|
658
|
-
static void ggml_compute_forward_dup_f32(
|
|
659
|
-
const ggml_compute_params * params,
|
|
660
|
-
ggml_tensor * dst) {
|
|
661
|
-
|
|
662
|
-
const ggml_tensor * src0 = dst->src[0];
|
|
663
|
-
|
|
664
|
-
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
|
665
|
-
|
|
666
|
-
GGML_TENSOR_UNARY_OP_LOCALS
|
|
667
|
-
|
|
668
|
-
const int ith = params->ith; // thread index
|
|
669
|
-
const int nth = params->nth; // number of threads
|
|
670
|
-
|
|
671
|
-
// parallelize by rows
|
|
672
|
-
const int nr = ne01;
|
|
673
|
-
// number of rows per thread
|
|
674
|
-
const int dr = (nr + nth - 1) / nth;
|
|
675
|
-
// row range for this thread
|
|
676
|
-
const int ir0 = dr * ith;
|
|
677
|
-
const int ir1 = MIN(ir0 + dr, nr);
|
|
678
|
-
|
|
679
|
-
if (src0->type == dst->type &&
|
|
680
|
-
ne00 == ne0 &&
|
|
681
|
-
nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
|
|
682
|
-
// copy by rows
|
|
683
|
-
const size_t rs = ne00*nb00;
|
|
684
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
685
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
686
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
687
|
-
memcpy(
|
|
688
|
-
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
|
|
689
|
-
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
|
|
690
|
-
rs);
|
|
691
|
-
}
|
|
692
|
-
}
|
|
693
|
-
}
|
|
694
|
-
return;
|
|
695
|
-
}
|
|
696
|
-
|
|
697
|
-
if (ggml_is_contiguous(dst)) {
|
|
698
|
-
// TODO: simplify
|
|
699
|
-
if (nb00 == sizeof(float)) {
|
|
700
|
-
if (ggml_get_type_traits_cpu(dst->type)->from_float) {
|
|
701
|
-
ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
|
|
702
|
-
|
|
703
|
-
size_t id = 0;
|
|
704
|
-
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
|
|
705
|
-
char * dst_ptr = (char *) dst->data;
|
|
706
|
-
|
|
707
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
708
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
709
|
-
id += rs * ir0;
|
|
710
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
711
|
-
const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
712
|
-
from_float(src0_ptr, dst_ptr + id, ne00);
|
|
713
|
-
id += rs;
|
|
714
|
-
}
|
|
715
|
-
id += rs * (ne01 - ir1);
|
|
716
|
-
}
|
|
717
|
-
}
|
|
718
|
-
} else {
|
|
719
|
-
GGML_ABORT("fatal error"); // TODO: implement
|
|
720
|
-
}
|
|
721
|
-
} else {
|
|
722
|
-
//printf("%s: this is not optimal - fix me\n", __func__);
|
|
723
|
-
|
|
724
|
-
if (dst->type == GGML_TYPE_F32) {
|
|
725
|
-
size_t id = 0;
|
|
726
|
-
float * dst_ptr = (float *) dst->data;
|
|
727
|
-
|
|
728
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
729
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
730
|
-
id += ne00 * ir0;
|
|
731
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
732
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
733
|
-
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
734
|
-
|
|
735
|
-
dst_ptr[id] = *src0_ptr;
|
|
736
|
-
id++;
|
|
737
|
-
}
|
|
738
|
-
}
|
|
739
|
-
id += ne00 * (ne01 - ir1);
|
|
740
|
-
}
|
|
741
|
-
}
|
|
742
|
-
} else if (dst->type == GGML_TYPE_F16) {
|
|
743
|
-
size_t id = 0;
|
|
744
|
-
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
|
745
|
-
|
|
746
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
747
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
748
|
-
id += ne00 * ir0;
|
|
749
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
750
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
751
|
-
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
752
|
-
|
|
753
|
-
dst_ptr[id] = GGML_CPU_FP32_TO_FP16(*src0_ptr);
|
|
754
|
-
id++;
|
|
755
|
-
}
|
|
756
|
-
}
|
|
757
|
-
id += ne00 * (ne01 - ir1);
|
|
758
|
-
}
|
|
759
|
-
}
|
|
760
|
-
} else if (dst->type == GGML_TYPE_BF16) {
|
|
108
|
+
} else {
|
|
109
|
+
// casting between non-quantized types
|
|
761
110
|
size_t id = 0;
|
|
762
|
-
|
|
111
|
+
dst_t * dst_ptr = (dst_t *) dst->data;
|
|
763
112
|
|
|
764
113
|
for (int i03 = 0; i03 < ne03; i03++) {
|
|
765
114
|
for (int i02 = 0; i02 < ne02; i02++) {
|
|
766
115
|
id += ne00 * ir0;
|
|
767
116
|
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
117
|
+
const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
768
118
|
for (int i00 = 0; i00 < ne00; i00++) {
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
|
|
119
|
+
float tmp = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
|
|
120
|
+
dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
|
|
772
121
|
id++;
|
|
773
122
|
}
|
|
774
123
|
}
|
|
775
124
|
id += ne00 * (ne01 - ir1);
|
|
776
125
|
}
|
|
777
126
|
}
|
|
778
|
-
} else {
|
|
779
|
-
GGML_ABORT("fatal error"); // TODO: implement
|
|
780
127
|
}
|
|
781
|
-
}
|
|
128
|
+
} else {
|
|
129
|
+
//printf("%s: this is not optimal - fix me\n", __func__);
|
|
130
|
+
|
|
131
|
+
size_t id = 0;
|
|
132
|
+
dst_t * dst_ptr = (dst_t *) dst->data;
|
|
133
|
+
|
|
134
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
|
135
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
|
136
|
+
id += ne00 * ir0;
|
|
137
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
138
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
|
139
|
+
const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
782
140
|
|
|
141
|
+
float tmp = type_conversion_table<src_t>::to_f32(*src0_ptr);
|
|
142
|
+
dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
|
|
143
|
+
id++;
|
|
144
|
+
}
|
|
145
|
+
}
|
|
146
|
+
id += ne00 * (ne01 - ir1);
|
|
147
|
+
}
|
|
148
|
+
}
|
|
149
|
+
}
|
|
783
150
|
return;
|
|
784
151
|
}
|
|
785
152
|
|
|
786
153
|
// dst counters
|
|
787
|
-
|
|
788
154
|
int64_t i10 = 0;
|
|
789
155
|
int64_t i11 = 0;
|
|
790
156
|
int64_t i12 = 0;
|
|
791
157
|
int64_t i13 = 0;
|
|
792
158
|
|
|
793
|
-
if (
|
|
159
|
+
if constexpr (std::is_same_v<dst_t, src_t>) {
|
|
794
160
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
795
161
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
796
162
|
i10 += ne00 * ir0;
|
|
@@ -811,15 +177,15 @@ static void ggml_compute_forward_dup_f32(
|
|
|
811
177
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
812
178
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
813
179
|
|
|
814
|
-
memcpy(dst_ptr, src0_ptr, sizeof(
|
|
180
|
+
memcpy(dst_ptr, src0_ptr, sizeof(dst_t));
|
|
815
181
|
|
|
816
|
-
if (++i10 ==
|
|
182
|
+
if (++i10 == ne00) {
|
|
817
183
|
i10 = 0;
|
|
818
|
-
if (++i11 ==
|
|
184
|
+
if (++i11 == ne01) {
|
|
819
185
|
i11 = 0;
|
|
820
|
-
if (++i12 ==
|
|
186
|
+
if (++i12 == ne02) {
|
|
821
187
|
i12 = 0;
|
|
822
|
-
if (++i13 ==
|
|
188
|
+
if (++i13 == ne03) {
|
|
823
189
|
i13 = 0;
|
|
824
190
|
}
|
|
825
191
|
}
|
|
@@ -842,7 +208,8 @@ static void ggml_compute_forward_dup_f32(
|
|
|
842
208
|
}
|
|
843
209
|
}
|
|
844
210
|
}
|
|
845
|
-
|
|
211
|
+
|
|
212
|
+
} else {
|
|
846
213
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
847
214
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
848
215
|
i10 += ne00 * ir0;
|
|
@@ -863,7 +230,8 @@ static void ggml_compute_forward_dup_f32(
|
|
|
863
230
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
864
231
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
865
232
|
|
|
866
|
-
|
|
233
|
+
float tmp = type_conversion_table<src_t>::to_f32(*(const src_t *) src0_ptr);
|
|
234
|
+
*(dst_t *) dst_ptr = type_conversion_table<dst_t>::from_f32(tmp);
|
|
867
235
|
|
|
868
236
|
if (++i10 == ne0) {
|
|
869
237
|
i10 = 0;
|
|
@@ -894,60 +262,63 @@ static void ggml_compute_forward_dup_f32(
|
|
|
894
262
|
}
|
|
895
263
|
}
|
|
896
264
|
}
|
|
897
|
-
}
|
|
898
|
-
|
|
899
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
900
|
-
i10 += ne00 * ir0;
|
|
901
|
-
while (i10 >= ne0) {
|
|
902
|
-
i10 -= ne0;
|
|
903
|
-
if (++i11 == ne1) {
|
|
904
|
-
i11 = 0;
|
|
905
|
-
if (++i12 == ne2) {
|
|
906
|
-
i12 = 0;
|
|
907
|
-
if (++i13 == ne3) {
|
|
908
|
-
i13 = 0;
|
|
909
|
-
}
|
|
910
|
-
}
|
|
911
|
-
}
|
|
912
|
-
}
|
|
913
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
914
|
-
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
915
|
-
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
916
|
-
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
265
|
+
}
|
|
266
|
+
}
|
|
917
267
|
|
|
918
|
-
*(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
|
|
919
268
|
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
269
|
+
template<typename src_t>
|
|
270
|
+
static void ggml_compute_forward_dup_to_q(
|
|
271
|
+
const ggml_compute_params * params,
|
|
272
|
+
ggml_tensor * dst) {
|
|
273
|
+
|
|
274
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
275
|
+
|
|
276
|
+
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
|
277
|
+
GGML_ASSERT(!ggml_is_quantized(src0->type));
|
|
278
|
+
|
|
279
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
|
280
|
+
|
|
281
|
+
const int ith = params->ith; // thread index
|
|
282
|
+
const int nth = params->nth; // number of threads
|
|
283
|
+
|
|
284
|
+
// parallelize by rows
|
|
285
|
+
const int nr = ne01;
|
|
286
|
+
// number of rows per thread
|
|
287
|
+
const int dr = (nr + nth - 1) / nth;
|
|
288
|
+
// row range for this thread
|
|
289
|
+
const int ir0 = dr * ith;
|
|
290
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
291
|
+
|
|
292
|
+
if (ggml_is_contiguous(dst) &&
|
|
293
|
+
nb00 == sizeof(src_t) &&
|
|
294
|
+
ggml_get_type_traits_cpu(dst->type)->from_float) {
|
|
295
|
+
// casting non-quantized types --> intermediate f32 --> quantized
|
|
296
|
+
ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
|
|
297
|
+
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
|
298
|
+
|
|
299
|
+
size_t id = 0;
|
|
300
|
+
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
|
|
301
|
+
char * dst_ptr = (char *) dst->data;
|
|
302
|
+
|
|
303
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
|
304
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
|
305
|
+
id += rs * ir0;
|
|
306
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
307
|
+
const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
308
|
+
|
|
309
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
|
310
|
+
src0_f32[i00] = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
|
|
945
311
|
}
|
|
312
|
+
|
|
313
|
+
quantize_row_q(src0_f32, dst_ptr + id, ne00);
|
|
314
|
+
id += rs;
|
|
946
315
|
}
|
|
316
|
+
id += rs * (ne01 - ir1);
|
|
947
317
|
}
|
|
948
318
|
}
|
|
949
319
|
} else {
|
|
950
|
-
|
|
320
|
+
// printf("%s %s\n", ggml_type_name(src0->type), ggml_type_name(dst->type));
|
|
321
|
+
GGML_ABORT("not implemented");
|
|
951
322
|
}
|
|
952
323
|
}
|
|
953
324
|
|
|
@@ -1101,7 +472,7 @@ static void ggml_compute_forward_dup_bytes(
|
|
|
1101
472
|
}
|
|
1102
473
|
}
|
|
1103
474
|
|
|
1104
|
-
static void
|
|
475
|
+
static void ggml_compute_forward_dup_from_q(
|
|
1105
476
|
const ggml_compute_params * params,
|
|
1106
477
|
ggml_tensor * dst) {
|
|
1107
478
|
|
|
@@ -1166,20 +537,35 @@ void ggml_compute_forward_dup(
|
|
|
1166
537
|
switch (src0->type) {
|
|
1167
538
|
case GGML_TYPE_F16:
|
|
1168
539
|
{
|
|
1169
|
-
|
|
540
|
+
/**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_fp16_t>(params, dst);
|
|
541
|
+
else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_bf16_t>(params, dst);
|
|
542
|
+
else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<ggml_fp16_t, float >(params, dst);
|
|
543
|
+
else ggml_compute_forward_dup_to_q<ggml_fp16_t>(params, dst);
|
|
1170
544
|
} break;
|
|
1171
545
|
case GGML_TYPE_BF16:
|
|
1172
546
|
{
|
|
1173
|
-
|
|
547
|
+
/**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_fp16_t>(params, dst);
|
|
548
|
+
else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_bf16_t>(params, dst);
|
|
549
|
+
else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<ggml_bf16_t, float >(params, dst);
|
|
550
|
+
else ggml_compute_forward_dup_to_q<ggml_bf16_t>(params, dst);
|
|
1174
551
|
} break;
|
|
1175
552
|
case GGML_TYPE_F32:
|
|
1176
553
|
{
|
|
1177
|
-
|
|
554
|
+
/**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<float, ggml_fp16_t>(params, dst);
|
|
555
|
+
else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<float, ggml_bf16_t>(params, dst);
|
|
556
|
+
else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<float, float >(params, dst);
|
|
557
|
+
else if (dst->type == GGML_TYPE_I32) ggml_compute_forward_dup_flt<float, int32_t >(params, dst);
|
|
558
|
+
else ggml_compute_forward_dup_to_q<float>(params, dst);
|
|
559
|
+
} break;
|
|
560
|
+
case GGML_TYPE_I32:
|
|
561
|
+
{
|
|
562
|
+
if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<int32_t, float>(params, dst);
|
|
563
|
+
else GGML_ABORT("not implemented");
|
|
1178
564
|
} break;
|
|
1179
565
|
default:
|
|
1180
566
|
{
|
|
1181
567
|
if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
|
|
1182
|
-
|
|
568
|
+
ggml_compute_forward_dup_from_q(params, dst);
|
|
1183
569
|
break;
|
|
1184
570
|
}
|
|
1185
571
|
GGML_ABORT("fatal error");
|
|
@@ -1283,6 +669,7 @@ void ggml_compute_forward_add(
|
|
|
1283
669
|
case GGML_TYPE_Q5_0:
|
|
1284
670
|
case GGML_TYPE_Q5_1:
|
|
1285
671
|
case GGML_TYPE_Q8_0:
|
|
672
|
+
case GGML_TYPE_MXFP4:
|
|
1286
673
|
case GGML_TYPE_Q2_K:
|
|
1287
674
|
case GGML_TYPE_Q3_K:
|
|
1288
675
|
case GGML_TYPE_Q4_K:
|
|
@@ -1309,6 +696,77 @@ void ggml_compute_forward_add(
|
|
|
1309
696
|
}
|
|
1310
697
|
}
|
|
1311
698
|
|
|
699
|
+
// ggml_compute_forward_add_id
|
|
700
|
+
|
|
701
|
+
static void ggml_compute_forward_add_id_f32(
|
|
702
|
+
const ggml_compute_params * params,
|
|
703
|
+
ggml_tensor * dst) {
|
|
704
|
+
|
|
705
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
706
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
707
|
+
const ggml_tensor * src2 = dst->src[2];
|
|
708
|
+
|
|
709
|
+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
710
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
711
|
+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
712
|
+
GGML_ASSERT(src2->type == GGML_TYPE_I32);
|
|
713
|
+
|
|
714
|
+
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
|
715
|
+
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
|
716
|
+
|
|
717
|
+
const int ith = params->ith;
|
|
718
|
+
const int nth = params->nth;
|
|
719
|
+
|
|
720
|
+
const int nr = ggml_nrows(src0);
|
|
721
|
+
|
|
722
|
+
GGML_TENSOR_TERNARY_OP_LOCALS
|
|
723
|
+
|
|
724
|
+
GGML_ASSERT( nb0 == sizeof(float));
|
|
725
|
+
GGML_ASSERT(nb10 == sizeof(float));
|
|
726
|
+
|
|
727
|
+
// rows per thread
|
|
728
|
+
const int dr = (nr + nth - 1)/nth;
|
|
729
|
+
|
|
730
|
+
// row range for this thread
|
|
731
|
+
const int ir0 = dr*ith;
|
|
732
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
733
|
+
|
|
734
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
735
|
+
// src0 indices
|
|
736
|
+
const int i3 = ir/(ne2*ne1);
|
|
737
|
+
const int i2 = (ir - i3*ne2*ne1)/ne1;
|
|
738
|
+
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
|
|
739
|
+
|
|
740
|
+
// src1 indices
|
|
741
|
+
const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
|
|
742
|
+
|
|
743
|
+
GGML_ASSERT(i11 >= 0 && i11 < ne11);
|
|
744
|
+
|
|
745
|
+
ggml_vec_add_f32(ne0,
|
|
746
|
+
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
|
|
747
|
+
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
|
|
748
|
+
(float *) ((char *) src1->data + i11*nb11));
|
|
749
|
+
}
|
|
750
|
+
}
|
|
751
|
+
|
|
752
|
+
void ggml_compute_forward_add_id(
|
|
753
|
+
const ggml_compute_params * params,
|
|
754
|
+
ggml_tensor * dst) {
|
|
755
|
+
|
|
756
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
757
|
+
|
|
758
|
+
switch (src0->type) {
|
|
759
|
+
case GGML_TYPE_F32:
|
|
760
|
+
{
|
|
761
|
+
ggml_compute_forward_add_id_f32(params, dst);
|
|
762
|
+
} break;
|
|
763
|
+
default:
|
|
764
|
+
{
|
|
765
|
+
GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
|
|
766
|
+
}
|
|
767
|
+
}
|
|
768
|
+
}
|
|
769
|
+
|
|
1312
770
|
// ggml_compute_forward_add1
|
|
1313
771
|
|
|
1314
772
|
static void ggml_compute_forward_add1_f32(
|
|
@@ -1660,6 +1118,7 @@ void ggml_compute_forward_add1(
|
|
|
1660
1118
|
case GGML_TYPE_Q5_1:
|
|
1661
1119
|
case GGML_TYPE_Q8_0:
|
|
1662
1120
|
case GGML_TYPE_Q8_1:
|
|
1121
|
+
case GGML_TYPE_MXFP4:
|
|
1663
1122
|
case GGML_TYPE_Q2_K:
|
|
1664
1123
|
case GGML_TYPE_Q3_K:
|
|
1665
1124
|
case GGML_TYPE_Q4_K:
|
|
@@ -1787,6 +1246,7 @@ void ggml_compute_forward_acc(
|
|
|
1787
1246
|
case GGML_TYPE_Q5_1:
|
|
1788
1247
|
case GGML_TYPE_Q8_0:
|
|
1789
1248
|
case GGML_TYPE_Q8_1:
|
|
1249
|
+
case GGML_TYPE_MXFP4:
|
|
1790
1250
|
case GGML_TYPE_Q2_K:
|
|
1791
1251
|
case GGML_TYPE_Q3_K:
|
|
1792
1252
|
case GGML_TYPE_Q4_K:
|
|
@@ -1936,6 +1396,56 @@ void ggml_compute_forward_sum(
|
|
|
1936
1396
|
}
|
|
1937
1397
|
}
|
|
1938
1398
|
|
|
1399
|
+
// ggml_compute_forward_cumsum
|
|
1400
|
+
|
|
1401
|
+
static void ggml_compute_forward_cumsum_f32(
|
|
1402
|
+
const ggml_compute_params * params,
|
|
1403
|
+
ggml_tensor * dst) {
|
|
1404
|
+
|
|
1405
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
1406
|
+
|
|
1407
|
+
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
|
1408
|
+
GGML_ASSERT(dst->nb[0] == sizeof(float));
|
|
1409
|
+
|
|
1410
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
|
1411
|
+
|
|
1412
|
+
GGML_ASSERT(ne0 == ne00);
|
|
1413
|
+
GGML_ASSERT(ne1 == ne01);
|
|
1414
|
+
GGML_ASSERT(ne2 == ne02);
|
|
1415
|
+
GGML_ASSERT(ne3 == ne03);
|
|
1416
|
+
|
|
1417
|
+
const auto [ir0, ir1] = get_thread_range(params, src0);
|
|
1418
|
+
|
|
1419
|
+
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
|
1420
|
+
const int64_t i03 = ir/(ne02*ne01);
|
|
1421
|
+
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
|
1422
|
+
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
|
1423
|
+
|
|
1424
|
+
float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
1425
|
+
float * dst_row = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
|
1426
|
+
|
|
1427
|
+
ggml_vec_cumsum_f32(ne00, dst_row, src_row);
|
|
1428
|
+
}
|
|
1429
|
+
}
|
|
1430
|
+
|
|
1431
|
+
void ggml_compute_forward_cumsum(
|
|
1432
|
+
const ggml_compute_params * params,
|
|
1433
|
+
ggml_tensor * dst) {
|
|
1434
|
+
|
|
1435
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
1436
|
+
|
|
1437
|
+
switch (src0->type) {
|
|
1438
|
+
case GGML_TYPE_F32:
|
|
1439
|
+
{
|
|
1440
|
+
ggml_compute_forward_cumsum_f32(params, dst);
|
|
1441
|
+
} break;
|
|
1442
|
+
default:
|
|
1443
|
+
{
|
|
1444
|
+
GGML_ABORT("fatal error");
|
|
1445
|
+
}
|
|
1446
|
+
}
|
|
1447
|
+
}
|
|
1448
|
+
|
|
1939
1449
|
// ggml_compute_forward_sum_rows
|
|
1940
1450
|
|
|
1941
1451
|
static void ggml_compute_forward_sum_rows_f32(
|
|
@@ -2656,24 +2166,101 @@ static void ggml_compute_forward_gelu_f16(
|
|
|
2656
2166
|
assert(!isnan(v));
|
|
2657
2167
|
assert(!isinf(v));
|
|
2658
2168
|
}
|
|
2659
|
-
#endif
|
|
2169
|
+
#endif
|
|
2170
|
+
}
|
|
2171
|
+
}
|
|
2172
|
+
|
|
2173
|
+
static void ggml_compute_forward_gelu(
|
|
2174
|
+
const ggml_compute_params * params,
|
|
2175
|
+
ggml_tensor * dst) {
|
|
2176
|
+
|
|
2177
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
2178
|
+
|
|
2179
|
+
switch (src0->type) {
|
|
2180
|
+
case GGML_TYPE_F32:
|
|
2181
|
+
{
|
|
2182
|
+
ggml_compute_forward_gelu_f32(params, dst);
|
|
2183
|
+
} break;
|
|
2184
|
+
case GGML_TYPE_F16:
|
|
2185
|
+
{
|
|
2186
|
+
ggml_compute_forward_gelu_f16(params, dst);
|
|
2187
|
+
} break;
|
|
2188
|
+
default:
|
|
2189
|
+
{
|
|
2190
|
+
GGML_ABORT("fatal error");
|
|
2191
|
+
}
|
|
2192
|
+
}
|
|
2193
|
+
}
|
|
2194
|
+
|
|
2195
|
+
// ggml_compute_forward_fill
|
|
2196
|
+
|
|
2197
|
+
static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
2198
|
+
const float c = ggml_get_op_params_f32(dst, 0);
|
|
2199
|
+
|
|
2200
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
|
|
2201
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
|
|
2202
|
+
|
|
2203
|
+
const auto [ir0, ir1] = get_thread_range(params, dst);
|
|
2204
|
+
|
|
2205
|
+
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
|
2206
|
+
const int64_t i03 = ir/(ne2*ne1);
|
|
2207
|
+
const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
|
|
2208
|
+
const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
|
|
2209
|
+
|
|
2210
|
+
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
|
|
2211
|
+
|
|
2212
|
+
ggml_vec_set_f32(ne0, dst_ptr, c);
|
|
2213
|
+
}
|
|
2214
|
+
}
|
|
2215
|
+
|
|
2216
|
+
void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
2217
|
+
ggml_compute_forward_fill_f32(params, dst);
|
|
2218
|
+
}
|
|
2219
|
+
|
|
2220
|
+
// ggml_compute_forward_tri
|
|
2221
|
+
|
|
2222
|
+
static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
2223
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
2224
|
+
|
|
2225
|
+
const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
|
|
2226
|
+
|
|
2227
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
2228
|
+
|
|
2229
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
|
2230
|
+
|
|
2231
|
+
const auto [ir0, ir1] = get_thread_range(params, src0);
|
|
2232
|
+
|
|
2233
|
+
bool (*bipred)(int, int);
|
|
2234
|
+
|
|
2235
|
+
switch (ttype) {
|
|
2236
|
+
case GGML_TRI_TYPE_LOWER: bipred = [](int i, int r) { return i < r; }; break;
|
|
2237
|
+
case GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
|
|
2238
|
+
case GGML_TRI_TYPE_UPPER: bipred = [](int i, int r) { return i > r; }; break;
|
|
2239
|
+
case GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
|
|
2240
|
+
default: GGML_ABORT("invalid tri type");
|
|
2241
|
+
}
|
|
2242
|
+
|
|
2243
|
+
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
|
2244
|
+
const int64_t i03 = ir/(ne02*ne01);
|
|
2245
|
+
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
|
2246
|
+
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
|
2247
|
+
|
|
2248
|
+
const float * src_ptr = (const float *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
|
2249
|
+
float * dst_ptr = ( float *) (( char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
|
|
2250
|
+
|
|
2251
|
+
for (int i0 = 0; i0 < ne0; ++i0) {
|
|
2252
|
+
dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
|
|
2253
|
+
}
|
|
2660
2254
|
}
|
|
2661
2255
|
}
|
|
2662
2256
|
|
|
2663
|
-
|
|
2664
|
-
const ggml_compute_params * params,
|
|
2665
|
-
ggml_tensor * dst) {
|
|
2666
|
-
|
|
2257
|
+
void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
2667
2258
|
const ggml_tensor * src0 = dst->src[0];
|
|
2668
2259
|
|
|
2669
2260
|
switch (src0->type) {
|
|
2670
2261
|
case GGML_TYPE_F32:
|
|
2671
2262
|
{
|
|
2672
|
-
|
|
2673
|
-
} break;
|
|
2674
|
-
case GGML_TYPE_F16:
|
|
2675
|
-
{
|
|
2676
|
-
ggml_compute_forward_gelu_f16(params, dst);
|
|
2263
|
+
ggml_compute_forward_tri_f32(params, dst);
|
|
2677
2264
|
} break;
|
|
2678
2265
|
default:
|
|
2679
2266
|
{
|
|
@@ -3032,27 +2619,281 @@ static void ggml_compute_forward_leaky_relu_f16(
|
|
|
3032
2619
|
return;
|
|
3033
2620
|
}
|
|
3034
2621
|
|
|
3035
|
-
assert(ggml_is_contiguous_1(src0));
|
|
3036
|
-
assert(ggml_is_contiguous_1(dst));
|
|
3037
|
-
assert(ggml_are_same_shape(src0, dst));
|
|
2622
|
+
assert(ggml_is_contiguous_1(src0));
|
|
2623
|
+
assert(ggml_is_contiguous_1(dst));
|
|
2624
|
+
assert(ggml_are_same_shape(src0, dst));
|
|
2625
|
+
|
|
2626
|
+
const int n = ggml_nrows(src0);
|
|
2627
|
+
const int nc = src0->ne[0];
|
|
2628
|
+
|
|
2629
|
+
float negative_slope;
|
|
2630
|
+
memcpy(&negative_slope, dst->op_params, sizeof(float));
|
|
2631
|
+
|
|
2632
|
+
assert(dst->nb[0] == sizeof(ggml_fp16_t));
|
|
2633
|
+
assert(src0->nb[0] == sizeof(ggml_fp16_t));
|
|
2634
|
+
|
|
2635
|
+
for (int i = 0; i < n; i++) {
|
|
2636
|
+
ggml_vec_leaky_relu_f16(nc,
|
|
2637
|
+
(ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])),
|
|
2638
|
+
(ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
|
|
2639
|
+
}
|
|
2640
|
+
}
|
|
2641
|
+
|
|
2642
|
+
void ggml_compute_forward_leaky_relu(
|
|
2643
|
+
const ggml_compute_params * params,
|
|
2644
|
+
ggml_tensor * dst) {
|
|
2645
|
+
|
|
2646
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
2647
|
+
|
|
2648
|
+
switch (src0->type) {
|
|
2649
|
+
case GGML_TYPE_F32:
|
|
2650
|
+
{
|
|
2651
|
+
ggml_compute_forward_leaky_relu_f32(params, dst);
|
|
2652
|
+
} break;
|
|
2653
|
+
case GGML_TYPE_F16:
|
|
2654
|
+
{
|
|
2655
|
+
ggml_compute_forward_leaky_relu_f16(params, dst);
|
|
2656
|
+
} break;
|
|
2657
|
+
default:
|
|
2658
|
+
{
|
|
2659
|
+
GGML_ABORT("fatal error");
|
|
2660
|
+
}
|
|
2661
|
+
}
|
|
2662
|
+
}
|
|
2663
|
+
|
|
2664
|
+
// ggml_compute_forward_silu_back
|
|
2665
|
+
|
|
2666
|
+
static void ggml_compute_forward_silu_back_f32(
|
|
2667
|
+
const ggml_compute_params * params,
|
|
2668
|
+
ggml_tensor * dst) {
|
|
2669
|
+
|
|
2670
|
+
const ggml_tensor * grad = dst->src[0];
|
|
2671
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
2672
|
+
|
|
2673
|
+
assert(ggml_is_contiguous_1(grad));
|
|
2674
|
+
assert(ggml_is_contiguous_1(src1));
|
|
2675
|
+
assert(ggml_is_contiguous_1(dst));
|
|
2676
|
+
assert(ggml_are_same_shape(src1, dst));
|
|
2677
|
+
assert(ggml_are_same_shape(src1, grad));
|
|
2678
|
+
|
|
2679
|
+
const int ith = params->ith;
|
|
2680
|
+
const int nth = params->nth;
|
|
2681
|
+
|
|
2682
|
+
const int nc = src1->ne[0];
|
|
2683
|
+
const int nr = ggml_nrows(src1);
|
|
2684
|
+
|
|
2685
|
+
// rows per thread
|
|
2686
|
+
const int dr = (nr + nth - 1)/nth;
|
|
2687
|
+
|
|
2688
|
+
// row range for this thread
|
|
2689
|
+
const int ir0 = dr*ith;
|
|
2690
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
2691
|
+
|
|
2692
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
2693
|
+
ggml_vec_silu_backward_f32(nc,
|
|
2694
|
+
(float *) ((char *) dst->data + i1*( dst->nb[1])),
|
|
2695
|
+
(float *) ((char *) src1->data + i1*(src1->nb[1])),
|
|
2696
|
+
(float *) ((char *) grad->data + i1*(grad->nb[1])));
|
|
2697
|
+
|
|
2698
|
+
#ifndef NDEBUG
|
|
2699
|
+
for (int k = 0; k < nc; k++) {
|
|
2700
|
+
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
2701
|
+
GGML_UNUSED(x);
|
|
2702
|
+
assert(!isnan(x));
|
|
2703
|
+
assert(!isinf(x));
|
|
2704
|
+
}
|
|
2705
|
+
#endif
|
|
2706
|
+
}
|
|
2707
|
+
}
|
|
2708
|
+
|
|
2709
|
+
static void ggml_compute_forward_silu_back_f16(
|
|
2710
|
+
const ggml_compute_params * params,
|
|
2711
|
+
ggml_tensor * dst) {
|
|
2712
|
+
|
|
2713
|
+
const ggml_tensor * grad = dst->src[0];
|
|
2714
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
2715
|
+
|
|
2716
|
+
assert(ggml_is_contiguous_1(grad));
|
|
2717
|
+
assert(ggml_is_contiguous_1(src1));
|
|
2718
|
+
assert(ggml_is_contiguous_1(dst));
|
|
2719
|
+
assert(ggml_are_same_shape(src1, dst));
|
|
2720
|
+
assert(ggml_are_same_shape(src1, grad));
|
|
2721
|
+
|
|
2722
|
+
const int ith = params->ith;
|
|
2723
|
+
const int nth = params->nth;
|
|
2724
|
+
|
|
2725
|
+
const int nc = src1->ne[0];
|
|
2726
|
+
const int nr = ggml_nrows(src1);
|
|
2727
|
+
|
|
2728
|
+
// rows per thread
|
|
2729
|
+
const int dr = (nr + nth - 1)/nth;
|
|
2730
|
+
|
|
2731
|
+
// row range for this thread
|
|
2732
|
+
const int ir0 = dr*ith;
|
|
2733
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
2734
|
+
|
|
2735
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
2736
|
+
ggml_vec_silu_backward_f16(nc,
|
|
2737
|
+
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
|
|
2738
|
+
(ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
|
|
2739
|
+
(ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
|
|
2740
|
+
|
|
2741
|
+
#ifndef NDEBUG
|
|
2742
|
+
for (int k = 0; k < nc; k++) {
|
|
2743
|
+
const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
2744
|
+
const float v = GGML_CPU_FP16_TO_FP32(x);
|
|
2745
|
+
GGML_UNUSED(v);
|
|
2746
|
+
assert(!isnan(v));
|
|
2747
|
+
assert(!isinf(v));
|
|
2748
|
+
}
|
|
2749
|
+
#endif
|
|
2750
|
+
}
|
|
2751
|
+
}
|
|
2752
|
+
|
|
2753
|
+
void ggml_compute_forward_silu_back(
|
|
2754
|
+
const ggml_compute_params * params,
|
|
2755
|
+
ggml_tensor * dst) {
|
|
2756
|
+
|
|
2757
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
2758
|
+
|
|
2759
|
+
switch (src0->type) {
|
|
2760
|
+
case GGML_TYPE_F32:
|
|
2761
|
+
{
|
|
2762
|
+
ggml_compute_forward_silu_back_f32(params, dst);
|
|
2763
|
+
} break;
|
|
2764
|
+
case GGML_TYPE_F16:
|
|
2765
|
+
{
|
|
2766
|
+
ggml_compute_forward_silu_back_f16(params, dst);
|
|
2767
|
+
} break;
|
|
2768
|
+
default:
|
|
2769
|
+
{
|
|
2770
|
+
GGML_ABORT("fatal error");
|
|
2771
|
+
}
|
|
2772
|
+
}
|
|
2773
|
+
}
|
|
2774
|
+
|
|
2775
|
+
// ggml_compute_forward_reglu
|
|
2776
|
+
|
|
2777
|
+
static void ggml_compute_forward_reglu_f32(
|
|
2778
|
+
const ggml_compute_params * params,
|
|
2779
|
+
ggml_tensor * dst) {
|
|
2780
|
+
|
|
2781
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
2782
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
2783
|
+
char * src0_d = (char *) src0->data;
|
|
2784
|
+
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
|
2785
|
+
const size_t src0_o = src0->nb[1];
|
|
2786
|
+
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
|
2787
|
+
|
|
2788
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
2789
|
+
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
|
2790
|
+
|
|
2791
|
+
if (src1) {
|
|
2792
|
+
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
|
2793
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
2794
|
+
}
|
|
2795
|
+
|
|
2796
|
+
const int ith = params->ith;
|
|
2797
|
+
const int nth = params->nth;
|
|
2798
|
+
|
|
2799
|
+
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
|
2800
|
+
const int nr = ggml_nrows(src0);
|
|
2801
|
+
|
|
2802
|
+
GGML_ASSERT(dst->ne[0] == nc);
|
|
2803
|
+
GGML_ASSERT(ggml_nrows(dst) == nr);
|
|
2804
|
+
|
|
2805
|
+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
|
2806
|
+
|
|
2807
|
+
// rows per thread
|
|
2808
|
+
const int dr = (nr + nth - 1)/nth;
|
|
2809
|
+
|
|
2810
|
+
// row range for this thread
|
|
2811
|
+
const int ir0 = dr*ith;
|
|
2812
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
2813
|
+
|
|
2814
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
2815
|
+
float * src0_p = (float *) (src0_d + i1*src0_o);
|
|
2816
|
+
float * src1_p = (float *) (src1_d + i1*src1_o);
|
|
2817
|
+
|
|
2818
|
+
if (!src1) {
|
|
2819
|
+
src0_p += swapped ? nc : 0;
|
|
2820
|
+
src1_p += swapped ? 0 : nc;
|
|
2821
|
+
}
|
|
2822
|
+
|
|
2823
|
+
ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
2824
|
+
|
|
2825
|
+
#ifndef NDEBUG
|
|
2826
|
+
for (int k = 0; k < nc; k++) {
|
|
2827
|
+
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
2828
|
+
GGML_UNUSED(x);
|
|
2829
|
+
assert(!isnan(x));
|
|
2830
|
+
assert(!isinf(x));
|
|
2831
|
+
}
|
|
2832
|
+
#endif
|
|
2833
|
+
}
|
|
2834
|
+
}
|
|
2835
|
+
|
|
2836
|
+
static void ggml_compute_forward_reglu_f16(
|
|
2837
|
+
const ggml_compute_params * params,
|
|
2838
|
+
ggml_tensor * dst) {
|
|
2839
|
+
|
|
2840
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
2841
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
2842
|
+
char * src0_d = (char *) src0->data;
|
|
2843
|
+
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
|
2844
|
+
const size_t src0_o = src0->nb[1];
|
|
2845
|
+
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
|
2846
|
+
|
|
2847
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
2848
|
+
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
|
2849
|
+
|
|
2850
|
+
if (src1) {
|
|
2851
|
+
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
|
2852
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
2853
|
+
}
|
|
2854
|
+
|
|
2855
|
+
const int ith = params->ith;
|
|
2856
|
+
const int nth = params->nth;
|
|
2857
|
+
|
|
2858
|
+
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
|
2859
|
+
const int nr = ggml_nrows(src0);
|
|
2860
|
+
|
|
2861
|
+
GGML_ASSERT(dst->ne[0] == nc);
|
|
2862
|
+
GGML_ASSERT(ggml_nrows(dst) == nr);
|
|
2863
|
+
|
|
2864
|
+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
|
2865
|
+
|
|
2866
|
+
// rows per thread
|
|
2867
|
+
const int dr = (nr + nth - 1)/nth;
|
|
2868
|
+
|
|
2869
|
+
// row range for this thread
|
|
2870
|
+
const int ir0 = dr*ith;
|
|
2871
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
3038
2872
|
|
|
3039
|
-
|
|
3040
|
-
|
|
2873
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
2874
|
+
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
|
|
2875
|
+
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
|
|
3041
2876
|
|
|
3042
|
-
|
|
3043
|
-
|
|
2877
|
+
if (!src1) {
|
|
2878
|
+
src0_p += swapped ? nc : 0;
|
|
2879
|
+
src1_p += swapped ? 0 : nc;
|
|
2880
|
+
}
|
|
3044
2881
|
|
|
3045
|
-
|
|
3046
|
-
assert(src0->nb[0] == sizeof(ggml_fp16_t));
|
|
2882
|
+
ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3047
2883
|
|
|
3048
|
-
|
|
3049
|
-
|
|
3050
|
-
|
|
3051
|
-
|
|
2884
|
+
#ifndef NDEBUG
|
|
2885
|
+
for (int k = 0; k < nc; k++) {
|
|
2886
|
+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
2887
|
+
const float v = GGML_FP16_TO_FP32(x);
|
|
2888
|
+
GGML_UNUSED(v);
|
|
2889
|
+
assert(!isnan(v));
|
|
2890
|
+
assert(!isinf(v));
|
|
2891
|
+
}
|
|
2892
|
+
#endif
|
|
3052
2893
|
}
|
|
3053
2894
|
}
|
|
3054
2895
|
|
|
3055
|
-
void
|
|
2896
|
+
static void ggml_compute_forward_reglu(
|
|
3056
2897
|
const ggml_compute_params * params,
|
|
3057
2898
|
ggml_tensor * dst) {
|
|
3058
2899
|
|
|
@@ -3061,11 +2902,11 @@ void ggml_compute_forward_leaky_relu(
|
|
|
3061
2902
|
switch (src0->type) {
|
|
3062
2903
|
case GGML_TYPE_F32:
|
|
3063
2904
|
{
|
|
3064
|
-
|
|
2905
|
+
ggml_compute_forward_reglu_f32(params, dst);
|
|
3065
2906
|
} break;
|
|
3066
2907
|
case GGML_TYPE_F16:
|
|
3067
2908
|
{
|
|
3068
|
-
|
|
2909
|
+
ggml_compute_forward_reglu_f16(params, dst);
|
|
3069
2910
|
} break;
|
|
3070
2911
|
default:
|
|
3071
2912
|
{
|
|
@@ -3074,26 +2915,37 @@ void ggml_compute_forward_leaky_relu(
|
|
|
3074
2915
|
}
|
|
3075
2916
|
}
|
|
3076
2917
|
|
|
3077
|
-
//
|
|
2918
|
+
// ggml_compute_forward_geglu
|
|
3078
2919
|
|
|
3079
|
-
static void
|
|
2920
|
+
static void ggml_compute_forward_geglu_f32(
|
|
3080
2921
|
const ggml_compute_params * params,
|
|
3081
2922
|
ggml_tensor * dst) {
|
|
3082
2923
|
|
|
3083
|
-
const ggml_tensor *
|
|
2924
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3084
2925
|
const ggml_tensor * src1 = dst->src[1];
|
|
2926
|
+
char * src0_d = (char *) src0->data;
|
|
2927
|
+
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
|
2928
|
+
const size_t src0_o = src0->nb[1];
|
|
2929
|
+
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
|
3085
2930
|
|
|
3086
|
-
|
|
3087
|
-
|
|
3088
|
-
|
|
3089
|
-
|
|
3090
|
-
|
|
2931
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
2932
|
+
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
|
2933
|
+
|
|
2934
|
+
if (src1) {
|
|
2935
|
+
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
|
2936
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
2937
|
+
}
|
|
3091
2938
|
|
|
3092
2939
|
const int ith = params->ith;
|
|
3093
2940
|
const int nth = params->nth;
|
|
3094
2941
|
|
|
3095
|
-
const int nc = src1->ne[0];
|
|
3096
|
-
const int nr = ggml_nrows(
|
|
2942
|
+
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
|
2943
|
+
const int nr = ggml_nrows(src0);
|
|
2944
|
+
|
|
2945
|
+
GGML_ASSERT(dst->ne[0] == nc);
|
|
2946
|
+
GGML_ASSERT(ggml_nrows(dst) == nr);
|
|
2947
|
+
|
|
2948
|
+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
|
3097
2949
|
|
|
3098
2950
|
// rows per thread
|
|
3099
2951
|
const int dr = (nr + nth - 1)/nth;
|
|
@@ -3103,10 +2955,15 @@ static void ggml_compute_forward_silu_back_f32(
|
|
|
3103
2955
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
3104
2956
|
|
|
3105
2957
|
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
3106
|
-
|
|
3107
|
-
|
|
3108
|
-
|
|
3109
|
-
|
|
2958
|
+
float * src0_p = (float *) (src0_d + i1*src0_o);
|
|
2959
|
+
float * src1_p = (float *) (src1_d + i1*src1_o);
|
|
2960
|
+
|
|
2961
|
+
if (!src1) {
|
|
2962
|
+
src0_p += swapped ? nc : 0;
|
|
2963
|
+
src1_p += swapped ? 0 : nc;
|
|
2964
|
+
}
|
|
2965
|
+
|
|
2966
|
+
ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3110
2967
|
|
|
3111
2968
|
#ifndef NDEBUG
|
|
3112
2969
|
for (int k = 0; k < nc; k++) {
|
|
@@ -3119,24 +2976,35 @@ static void ggml_compute_forward_silu_back_f32(
|
|
|
3119
2976
|
}
|
|
3120
2977
|
}
|
|
3121
2978
|
|
|
3122
|
-
static void
|
|
2979
|
+
static void ggml_compute_forward_geglu_f16(
|
|
3123
2980
|
const ggml_compute_params * params,
|
|
3124
2981
|
ggml_tensor * dst) {
|
|
3125
2982
|
|
|
3126
|
-
const ggml_tensor *
|
|
2983
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3127
2984
|
const ggml_tensor * src1 = dst->src[1];
|
|
2985
|
+
char * src0_d = (char *) src0->data;
|
|
2986
|
+
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
|
2987
|
+
const size_t src0_o = src0->nb[1];
|
|
2988
|
+
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
|
3128
2989
|
|
|
3129
|
-
|
|
3130
|
-
|
|
3131
|
-
|
|
3132
|
-
|
|
3133
|
-
|
|
2990
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
2991
|
+
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
|
2992
|
+
|
|
2993
|
+
if (src1) {
|
|
2994
|
+
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
|
2995
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
2996
|
+
}
|
|
3134
2997
|
|
|
3135
2998
|
const int ith = params->ith;
|
|
3136
2999
|
const int nth = params->nth;
|
|
3137
3000
|
|
|
3138
|
-
const int nc = src1->ne[0];
|
|
3139
|
-
const int nr = ggml_nrows(
|
|
3001
|
+
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
|
3002
|
+
const int nr = ggml_nrows(src0);
|
|
3003
|
+
|
|
3004
|
+
GGML_ASSERT(dst->ne[0] == nc);
|
|
3005
|
+
GGML_ASSERT(ggml_nrows(dst) == nr);
|
|
3006
|
+
|
|
3007
|
+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
|
3140
3008
|
|
|
3141
3009
|
// rows per thread
|
|
3142
3010
|
const int dr = (nr + nth - 1)/nth;
|
|
@@ -3146,24 +3014,29 @@ static void ggml_compute_forward_silu_back_f16(
|
|
|
3146
3014
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
3147
3015
|
|
|
3148
3016
|
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
3149
|
-
|
|
3150
|
-
|
|
3151
|
-
(ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
|
|
3152
|
-
(ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
|
|
3017
|
+
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
|
|
3018
|
+
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
|
|
3153
3019
|
|
|
3154
|
-
|
|
3020
|
+
if (!src1) {
|
|
3021
|
+
src0_p += swapped ? nc : 0;
|
|
3022
|
+
src1_p += swapped ? 0 : nc;
|
|
3023
|
+
}
|
|
3024
|
+
|
|
3025
|
+
ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3026
|
+
|
|
3027
|
+
#ifndef NDEBUG
|
|
3155
3028
|
for (int k = 0; k < nc; k++) {
|
|
3156
|
-
const
|
|
3157
|
-
const float v =
|
|
3029
|
+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
3030
|
+
const float v = GGML_FP16_TO_FP32(x);
|
|
3158
3031
|
GGML_UNUSED(v);
|
|
3159
3032
|
assert(!isnan(v));
|
|
3160
3033
|
assert(!isinf(v));
|
|
3161
3034
|
}
|
|
3162
|
-
|
|
3035
|
+
#endif
|
|
3163
3036
|
}
|
|
3164
3037
|
}
|
|
3165
3038
|
|
|
3166
|
-
void
|
|
3039
|
+
static void ggml_compute_forward_geglu(
|
|
3167
3040
|
const ggml_compute_params * params,
|
|
3168
3041
|
ggml_tensor * dst) {
|
|
3169
3042
|
|
|
@@ -3172,11 +3045,11 @@ void ggml_compute_forward_silu_back(
|
|
|
3172
3045
|
switch (src0->type) {
|
|
3173
3046
|
case GGML_TYPE_F32:
|
|
3174
3047
|
{
|
|
3175
|
-
|
|
3048
|
+
ggml_compute_forward_geglu_f32(params, dst);
|
|
3176
3049
|
} break;
|
|
3177
3050
|
case GGML_TYPE_F16:
|
|
3178
3051
|
{
|
|
3179
|
-
|
|
3052
|
+
ggml_compute_forward_geglu_f16(params, dst);
|
|
3180
3053
|
} break;
|
|
3181
3054
|
default:
|
|
3182
3055
|
{
|
|
@@ -3185,9 +3058,9 @@ void ggml_compute_forward_silu_back(
|
|
|
3185
3058
|
}
|
|
3186
3059
|
}
|
|
3187
3060
|
|
|
3188
|
-
//
|
|
3061
|
+
// ggml_compute_forward_swiglu
|
|
3189
3062
|
|
|
3190
|
-
static void
|
|
3063
|
+
static void ggml_compute_forward_swiglu_f32(
|
|
3191
3064
|
const ggml_compute_params * params,
|
|
3192
3065
|
ggml_tensor * dst) {
|
|
3193
3066
|
|
|
@@ -3233,7 +3106,7 @@ static void ggml_compute_forward_reglu_f32(
|
|
|
3233
3106
|
src1_p += swapped ? 0 : nc;
|
|
3234
3107
|
}
|
|
3235
3108
|
|
|
3236
|
-
|
|
3109
|
+
ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3237
3110
|
|
|
3238
3111
|
#ifndef NDEBUG
|
|
3239
3112
|
for (int k = 0; k < nc; k++) {
|
|
@@ -3246,7 +3119,7 @@ static void ggml_compute_forward_reglu_f32(
|
|
|
3246
3119
|
}
|
|
3247
3120
|
}
|
|
3248
3121
|
|
|
3249
|
-
static void
|
|
3122
|
+
static void ggml_compute_forward_swiglu_f16(
|
|
3250
3123
|
const ggml_compute_params * params,
|
|
3251
3124
|
ggml_tensor * dst) {
|
|
3252
3125
|
|
|
@@ -3292,7 +3165,7 @@ static void ggml_compute_forward_reglu_f16(
|
|
|
3292
3165
|
src1_p += swapped ? 0 : nc;
|
|
3293
3166
|
}
|
|
3294
3167
|
|
|
3295
|
-
|
|
3168
|
+
ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3296
3169
|
|
|
3297
3170
|
#ifndef NDEBUG
|
|
3298
3171
|
for (int k = 0; k < nc; k++) {
|
|
@@ -3306,7 +3179,7 @@ static void ggml_compute_forward_reglu_f16(
|
|
|
3306
3179
|
}
|
|
3307
3180
|
}
|
|
3308
3181
|
|
|
3309
|
-
static void
|
|
3182
|
+
static void ggml_compute_forward_swiglu(
|
|
3310
3183
|
const ggml_compute_params * params,
|
|
3311
3184
|
ggml_tensor * dst) {
|
|
3312
3185
|
|
|
@@ -3315,11 +3188,11 @@ static void ggml_compute_forward_reglu(
|
|
|
3315
3188
|
switch (src0->type) {
|
|
3316
3189
|
case GGML_TYPE_F32:
|
|
3317
3190
|
{
|
|
3318
|
-
|
|
3191
|
+
ggml_compute_forward_swiglu_f32(params, dst);
|
|
3319
3192
|
} break;
|
|
3320
3193
|
case GGML_TYPE_F16:
|
|
3321
3194
|
{
|
|
3322
|
-
|
|
3195
|
+
ggml_compute_forward_swiglu_f16(params, dst);
|
|
3323
3196
|
} break;
|
|
3324
3197
|
default:
|
|
3325
3198
|
{
|
|
@@ -3328,9 +3201,9 @@ static void ggml_compute_forward_reglu(
|
|
|
3328
3201
|
}
|
|
3329
3202
|
}
|
|
3330
3203
|
|
|
3331
|
-
//
|
|
3204
|
+
// ggml_compute_forward_swiglu_oai
|
|
3332
3205
|
|
|
3333
|
-
static void
|
|
3206
|
+
static void ggml_compute_forward_swiglu_oai_f32(
|
|
3334
3207
|
const ggml_compute_params * params,
|
|
3335
3208
|
ggml_tensor * dst) {
|
|
3336
3209
|
|
|
@@ -3359,6 +3232,8 @@ static void ggml_compute_forward_geglu_f32(
|
|
|
3359
3232
|
GGML_ASSERT(ggml_nrows(dst) == nr);
|
|
3360
3233
|
|
|
3361
3234
|
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
|
3235
|
+
const float alpha = ggml_get_op_params_f32(dst, 2);
|
|
3236
|
+
const float limit = ggml_get_op_params_f32(dst, 3);
|
|
3362
3237
|
|
|
3363
3238
|
// rows per thread
|
|
3364
3239
|
const int dr = (nr + nth - 1)/nth;
|
|
@@ -3370,13 +3245,98 @@ static void ggml_compute_forward_geglu_f32(
|
|
|
3370
3245
|
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
3371
3246
|
float * src0_p = (float *) (src0_d + i1*src0_o);
|
|
3372
3247
|
float * src1_p = (float *) (src1_d + i1*src1_o);
|
|
3248
|
+
float * dst_p = (float *) ((char *) dst->data + i1*(dst->nb[1]));
|
|
3373
3249
|
|
|
3374
3250
|
if (!src1) {
|
|
3375
3251
|
src0_p += swapped ? nc : 0;
|
|
3376
3252
|
src1_p += swapped ? 0 : nc;
|
|
3377
3253
|
}
|
|
3378
3254
|
|
|
3379
|
-
|
|
3255
|
+
for (int k = 0; k < nc; k++) {
|
|
3256
|
+
const float x = std::min(src0_p[k], limit);
|
|
3257
|
+
const float y = std::clamp(src1_p[k], -limit, limit);
|
|
3258
|
+
const float out_glu = x / (1.f + expf(alpha * (-x)));
|
|
3259
|
+
dst_p[k] = out_glu * (y + 1.f);
|
|
3260
|
+
}
|
|
3261
|
+
|
|
3262
|
+
#ifndef NDEBUG
|
|
3263
|
+
for (int k = 0; k < nc; k++) {
|
|
3264
|
+
const float x = dst_p[k];
|
|
3265
|
+
GGML_UNUSED(x);
|
|
3266
|
+
assert(!isnan(x));
|
|
3267
|
+
assert(!isinf(x));
|
|
3268
|
+
}
|
|
3269
|
+
#endif
|
|
3270
|
+
}
|
|
3271
|
+
}
|
|
3272
|
+
|
|
3273
|
+
static void ggml_compute_forward_swiglu_oai(
|
|
3274
|
+
const ggml_compute_params * params,
|
|
3275
|
+
ggml_tensor * dst) {
|
|
3276
|
+
|
|
3277
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3278
|
+
|
|
3279
|
+
switch (src0->type) {
|
|
3280
|
+
case GGML_TYPE_F32:
|
|
3281
|
+
{
|
|
3282
|
+
ggml_compute_forward_swiglu_oai_f32(params, dst);
|
|
3283
|
+
} break;
|
|
3284
|
+
default:
|
|
3285
|
+
{
|
|
3286
|
+
GGML_ABORT("fatal error");
|
|
3287
|
+
}
|
|
3288
|
+
}
|
|
3289
|
+
}
|
|
3290
|
+
|
|
3291
|
+
// ggml_compute_forward_geglu_erf
|
|
3292
|
+
|
|
3293
|
+
static void ggml_compute_forward_geglu_erf_f32(
|
|
3294
|
+
const ggml_compute_params * params,
|
|
3295
|
+
ggml_tensor * dst) {
|
|
3296
|
+
|
|
3297
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3298
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
3299
|
+
char * src0_d = (char *) src0->data;
|
|
3300
|
+
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
|
3301
|
+
const size_t src0_o = src0->nb[1];
|
|
3302
|
+
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
|
3303
|
+
|
|
3304
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
3305
|
+
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
|
3306
|
+
|
|
3307
|
+
if (src1) {
|
|
3308
|
+
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
|
3309
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
3310
|
+
}
|
|
3311
|
+
|
|
3312
|
+
const int ith = params->ith;
|
|
3313
|
+
const int nth = params->nth;
|
|
3314
|
+
|
|
3315
|
+
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
|
3316
|
+
const int nr = ggml_nrows(src0);
|
|
3317
|
+
|
|
3318
|
+
GGML_ASSERT(dst->ne[0] == nc);
|
|
3319
|
+
GGML_ASSERT(ggml_nrows(dst) == nr);
|
|
3320
|
+
|
|
3321
|
+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
|
3322
|
+
|
|
3323
|
+
// rows per thread
|
|
3324
|
+
const int dr = (nr + nth - 1)/nth;
|
|
3325
|
+
|
|
3326
|
+
// row range for this thread
|
|
3327
|
+
const int ir0 = dr*ith;
|
|
3328
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
3329
|
+
|
|
3330
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
3331
|
+
float * src0_p = (float *) (src0_d + i1*src0_o);
|
|
3332
|
+
float * src1_p = (float *) (src1_d + i1*src1_o);
|
|
3333
|
+
|
|
3334
|
+
if (!src1) {
|
|
3335
|
+
src0_p += swapped ? nc : 0;
|
|
3336
|
+
src1_p += swapped ? 0 : nc;
|
|
3337
|
+
}
|
|
3338
|
+
|
|
3339
|
+
ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3380
3340
|
|
|
3381
3341
|
#ifndef NDEBUG
|
|
3382
3342
|
for (int k = 0; k < nc; k++) {
|
|
@@ -3389,7 +3349,7 @@ static void ggml_compute_forward_geglu_f32(
|
|
|
3389
3349
|
}
|
|
3390
3350
|
}
|
|
3391
3351
|
|
|
3392
|
-
static void
|
|
3352
|
+
static void ggml_compute_forward_geglu_erf_f16(
|
|
3393
3353
|
const ggml_compute_params * params,
|
|
3394
3354
|
ggml_tensor * dst) {
|
|
3395
3355
|
|
|
@@ -3435,7 +3395,7 @@ static void ggml_compute_forward_geglu_f16(
|
|
|
3435
3395
|
src1_p += swapped ? 0 : nc;
|
|
3436
3396
|
}
|
|
3437
3397
|
|
|
3438
|
-
|
|
3398
|
+
ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3439
3399
|
|
|
3440
3400
|
#ifndef NDEBUG
|
|
3441
3401
|
for (int k = 0; k < nc; k++) {
|
|
@@ -3449,7 +3409,7 @@ static void ggml_compute_forward_geglu_f16(
|
|
|
3449
3409
|
}
|
|
3450
3410
|
}
|
|
3451
3411
|
|
|
3452
|
-
static void
|
|
3412
|
+
static void ggml_compute_forward_geglu_erf(
|
|
3453
3413
|
const ggml_compute_params * params,
|
|
3454
3414
|
ggml_tensor * dst) {
|
|
3455
3415
|
|
|
@@ -3458,11 +3418,11 @@ static void ggml_compute_forward_geglu(
|
|
|
3458
3418
|
switch (src0->type) {
|
|
3459
3419
|
case GGML_TYPE_F32:
|
|
3460
3420
|
{
|
|
3461
|
-
|
|
3421
|
+
ggml_compute_forward_geglu_erf_f32(params, dst);
|
|
3462
3422
|
} break;
|
|
3463
3423
|
case GGML_TYPE_F16:
|
|
3464
3424
|
{
|
|
3465
|
-
|
|
3425
|
+
ggml_compute_forward_geglu_erf_f16(params, dst);
|
|
3466
3426
|
} break;
|
|
3467
3427
|
default:
|
|
3468
3428
|
{
|
|
@@ -3471,9 +3431,9 @@ static void ggml_compute_forward_geglu(
|
|
|
3471
3431
|
}
|
|
3472
3432
|
}
|
|
3473
3433
|
|
|
3474
|
-
//
|
|
3434
|
+
// ggml_compute_forward_geglu_quick
|
|
3475
3435
|
|
|
3476
|
-
static void
|
|
3436
|
+
static void ggml_compute_forward_geglu_quick_f32(
|
|
3477
3437
|
const ggml_compute_params * params,
|
|
3478
3438
|
ggml_tensor * dst) {
|
|
3479
3439
|
|
|
@@ -3519,7 +3479,7 @@ static void ggml_compute_forward_swiglu_f32(
|
|
|
3519
3479
|
src1_p += swapped ? 0 : nc;
|
|
3520
3480
|
}
|
|
3521
3481
|
|
|
3522
|
-
|
|
3482
|
+
ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3523
3483
|
|
|
3524
3484
|
#ifndef NDEBUG
|
|
3525
3485
|
for (int k = 0; k < nc; k++) {
|
|
@@ -3532,7 +3492,7 @@ static void ggml_compute_forward_swiglu_f32(
|
|
|
3532
3492
|
}
|
|
3533
3493
|
}
|
|
3534
3494
|
|
|
3535
|
-
static void
|
|
3495
|
+
static void ggml_compute_forward_geglu_quick_f16(
|
|
3536
3496
|
const ggml_compute_params * params,
|
|
3537
3497
|
ggml_tensor * dst) {
|
|
3538
3498
|
|
|
@@ -3578,7 +3538,7 @@ static void ggml_compute_forward_swiglu_f16(
|
|
|
3578
3538
|
src1_p += swapped ? 0 : nc;
|
|
3579
3539
|
}
|
|
3580
3540
|
|
|
3581
|
-
|
|
3541
|
+
ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3582
3542
|
|
|
3583
3543
|
#ifndef NDEBUG
|
|
3584
3544
|
for (int k = 0; k < nc; k++) {
|
|
@@ -3592,7 +3552,7 @@ static void ggml_compute_forward_swiglu_f16(
|
|
|
3592
3552
|
}
|
|
3593
3553
|
}
|
|
3594
3554
|
|
|
3595
|
-
static void
|
|
3555
|
+
static void ggml_compute_forward_geglu_quick(
|
|
3596
3556
|
const ggml_compute_params * params,
|
|
3597
3557
|
ggml_tensor * dst) {
|
|
3598
3558
|
|
|
@@ -3601,11 +3561,11 @@ static void ggml_compute_forward_swiglu(
|
|
|
3601
3561
|
switch (src0->type) {
|
|
3602
3562
|
case GGML_TYPE_F32:
|
|
3603
3563
|
{
|
|
3604
|
-
|
|
3564
|
+
ggml_compute_forward_geglu_quick_f32(params, dst);
|
|
3605
3565
|
} break;
|
|
3606
3566
|
case GGML_TYPE_F16:
|
|
3607
3567
|
{
|
|
3608
|
-
|
|
3568
|
+
ggml_compute_forward_geglu_quick_f16(params, dst);
|
|
3609
3569
|
} break;
|
|
3610
3570
|
default:
|
|
3611
3571
|
{
|
|
@@ -3636,31 +3596,27 @@ static void ggml_compute_forward_norm_f32(
|
|
|
3636
3596
|
|
|
3637
3597
|
GGML_ASSERT(eps >= 0.0f);
|
|
3638
3598
|
|
|
3639
|
-
// TODO: optimize
|
|
3640
3599
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
3641
3600
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
3642
3601
|
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
|
3643
3602
|
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
3644
3603
|
|
|
3645
|
-
|
|
3646
|
-
|
|
3647
|
-
sum += (ggml_float)x[i00];
|
|
3648
|
-
}
|
|
3649
|
-
|
|
3604
|
+
float sum = 0.0;
|
|
3605
|
+
ggml_vec_sum_f32(ne00, &sum, x);
|
|
3650
3606
|
float mean = sum/ne00;
|
|
3651
3607
|
|
|
3652
3608
|
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
|
3609
|
+
float variance = 0;
|
|
3653
3610
|
|
|
3654
|
-
|
|
3655
|
-
|
|
3656
|
-
|
|
3657
|
-
|
|
3658
|
-
|
|
3659
|
-
|
|
3611
|
+
#ifdef GGML_USE_ACCELERATE
|
|
3612
|
+
mean = -mean;
|
|
3613
|
+
vDSP_vsadd(x, 1, &mean, y, 1, ne00);
|
|
3614
|
+
vDSP_measqv(y, 1, &variance, ne00);
|
|
3615
|
+
#else
|
|
3616
|
+
variance = ggml_vec_cvar_f32(ne00, y, x, mean);
|
|
3617
|
+
#endif //GGML_USE_ACCELERATE
|
|
3660
3618
|
|
|
3661
|
-
float variance = sum2/ne00;
|
|
3662
3619
|
const float scale = 1.0f/sqrtf(variance + eps);
|
|
3663
|
-
|
|
3664
3620
|
ggml_vec_scale_f32(ne00, y, scale);
|
|
3665
3621
|
}
|
|
3666
3622
|
}
|
|
@@ -3729,6 +3685,9 @@ static void ggml_compute_forward_rms_norm_f32(
|
|
|
3729
3685
|
|
|
3730
3686
|
const float scale = 1.0f/sqrtf(mean + eps);
|
|
3731
3687
|
|
|
3688
|
+
// if you hit this, likely you got an inf somewhere earlier
|
|
3689
|
+
assert(scale > 0.0f);
|
|
3690
|
+
|
|
3732
3691
|
ggml_vec_scale_f32(ne00, y, scale);
|
|
3733
3692
|
}
|
|
3734
3693
|
}
|
|
@@ -4310,6 +4269,7 @@ void ggml_compute_forward_out_prod(
|
|
|
4310
4269
|
case GGML_TYPE_Q5_0:
|
|
4311
4270
|
case GGML_TYPE_Q5_1:
|
|
4312
4271
|
case GGML_TYPE_Q8_0:
|
|
4272
|
+
case GGML_TYPE_MXFP4:
|
|
4313
4273
|
case GGML_TYPE_Q2_K:
|
|
4314
4274
|
case GGML_TYPE_Q3_K:
|
|
4315
4275
|
case GGML_TYPE_Q4_K:
|
|
@@ -4357,9 +4317,11 @@ static void ggml_compute_forward_scale_f32(
|
|
|
4357
4317
|
GGML_ASSERT(ggml_is_contiguous(dst));
|
|
4358
4318
|
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
|
4359
4319
|
|
|
4360
|
-
// scale factor
|
|
4361
|
-
float
|
|
4362
|
-
|
|
4320
|
+
float s; // scale factor
|
|
4321
|
+
float b; // bias
|
|
4322
|
+
|
|
4323
|
+
memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
|
|
4324
|
+
memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
|
|
4363
4325
|
|
|
4364
4326
|
const int ith = params->ith;
|
|
4365
4327
|
const int nth = params->nth;
|
|
@@ -4378,12 +4340,22 @@ static void ggml_compute_forward_scale_f32(
|
|
|
4378
4340
|
|
|
4379
4341
|
const size_t nb1 = dst->nb[1];
|
|
4380
4342
|
|
|
4381
|
-
|
|
4382
|
-
|
|
4383
|
-
|
|
4384
|
-
|
|
4343
|
+
if (b == 0.0f) {
|
|
4344
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
4345
|
+
if (dst->data != src0->data) {
|
|
4346
|
+
// src0 is same shape as dst => same indices
|
|
4347
|
+
// TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
|
|
4348
|
+
memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
|
|
4349
|
+
}
|
|
4350
|
+
ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
|
|
4351
|
+
}
|
|
4352
|
+
} else {
|
|
4353
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
4354
|
+
ggml_vec_mad1_f32(nc,
|
|
4355
|
+
(float *) ((char *) dst->data + i1*nb1),
|
|
4356
|
+
(float *) ((char *) src0->data + i1*nb1),
|
|
4357
|
+
s, b);
|
|
4385
4358
|
}
|
|
4386
|
-
ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
|
|
4387
4359
|
}
|
|
4388
4360
|
}
|
|
4389
4361
|
|
|
@@ -4572,6 +4544,7 @@ void ggml_compute_forward_set(
|
|
|
4572
4544
|
case GGML_TYPE_Q5_1:
|
|
4573
4545
|
case GGML_TYPE_Q8_0:
|
|
4574
4546
|
case GGML_TYPE_Q8_1:
|
|
4547
|
+
case GGML_TYPE_MXFP4:
|
|
4575
4548
|
case GGML_TYPE_Q2_K:
|
|
4576
4549
|
case GGML_TYPE_Q3_K:
|
|
4577
4550
|
case GGML_TYPE_Q4_K:
|
|
@@ -4611,46 +4584,6 @@ void ggml_compute_forward_cont(
|
|
|
4611
4584
|
ggml_compute_forward_dup(params, dst);
|
|
4612
4585
|
}
|
|
4613
4586
|
|
|
4614
|
-
// ggml_compute_forward_reshape
|
|
4615
|
-
|
|
4616
|
-
void ggml_compute_forward_reshape(
|
|
4617
|
-
const ggml_compute_params * params,
|
|
4618
|
-
ggml_tensor * dst) {
|
|
4619
|
-
// NOP
|
|
4620
|
-
GGML_UNUSED(params);
|
|
4621
|
-
GGML_UNUSED(dst);
|
|
4622
|
-
}
|
|
4623
|
-
|
|
4624
|
-
// ggml_compute_forward_view
|
|
4625
|
-
|
|
4626
|
-
void ggml_compute_forward_view(
|
|
4627
|
-
const ggml_compute_params * params,
|
|
4628
|
-
ggml_tensor * dst) {
|
|
4629
|
-
// NOP
|
|
4630
|
-
GGML_UNUSED(params);
|
|
4631
|
-
GGML_UNUSED(dst);
|
|
4632
|
-
}
|
|
4633
|
-
|
|
4634
|
-
// ggml_compute_forward_permute
|
|
4635
|
-
|
|
4636
|
-
void ggml_compute_forward_permute(
|
|
4637
|
-
const ggml_compute_params * params,
|
|
4638
|
-
ggml_tensor * dst) {
|
|
4639
|
-
// NOP
|
|
4640
|
-
GGML_UNUSED(params);
|
|
4641
|
-
GGML_UNUSED(dst);
|
|
4642
|
-
}
|
|
4643
|
-
|
|
4644
|
-
// ggml_compute_forward_transpose
|
|
4645
|
-
|
|
4646
|
-
void ggml_compute_forward_transpose(
|
|
4647
|
-
const ggml_compute_params * params,
|
|
4648
|
-
ggml_tensor * dst) {
|
|
4649
|
-
// NOP
|
|
4650
|
-
GGML_UNUSED(params);
|
|
4651
|
-
GGML_UNUSED(dst);
|
|
4652
|
-
}
|
|
4653
|
-
|
|
4654
4587
|
// ggml_compute_forward_get_rows
|
|
4655
4588
|
|
|
4656
4589
|
static void ggml_compute_forward_get_rows_q(
|
|
@@ -4833,6 +4766,7 @@ void ggml_compute_forward_get_rows(
|
|
|
4833
4766
|
case GGML_TYPE_Q5_1:
|
|
4834
4767
|
case GGML_TYPE_Q8_0:
|
|
4835
4768
|
case GGML_TYPE_Q8_1:
|
|
4769
|
+
case GGML_TYPE_MXFP4:
|
|
4836
4770
|
case GGML_TYPE_Q2_K:
|
|
4837
4771
|
case GGML_TYPE_Q3_K:
|
|
4838
4772
|
case GGML_TYPE_Q4_K:
|
|
@@ -4890,6 +4824,7 @@ void ggml_compute_forward_get_rows(
|
|
|
4890
4824
|
//}
|
|
4891
4825
|
}
|
|
4892
4826
|
|
|
4827
|
+
template<typename idx_t>
|
|
4893
4828
|
static void ggml_compute_forward_set_rows_f32(
|
|
4894
4829
|
const ggml_compute_params * params,
|
|
4895
4830
|
ggml_tensor * dst) {
|
|
@@ -4928,7 +4863,7 @@ static void ggml_compute_forward_set_rows_f32(
|
|
|
4928
4863
|
const int64_t i11 = i02%ne11;
|
|
4929
4864
|
const int64_t i10 = i;
|
|
4930
4865
|
|
|
4931
|
-
const int64_t i1 = *(
|
|
4866
|
+
const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
|
|
4932
4867
|
|
|
4933
4868
|
GGML_ASSERT(i1 >= 0 && i1 < ne1);
|
|
4934
4869
|
|
|
@@ -4945,11 +4880,18 @@ void ggml_compute_forward_set_rows(
|
|
|
4945
4880
|
ggml_tensor * dst) {
|
|
4946
4881
|
|
|
4947
4882
|
const ggml_tensor * src0 = dst->src[0];
|
|
4883
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
4948
4884
|
|
|
4949
4885
|
switch (src0->type) {
|
|
4950
4886
|
case GGML_TYPE_F32:
|
|
4951
4887
|
{
|
|
4952
|
-
|
|
4888
|
+
if (src1->type == GGML_TYPE_I64) {
|
|
4889
|
+
ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
|
|
4890
|
+
} else if (src1->type == GGML_TYPE_I32) {
|
|
4891
|
+
ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
|
|
4892
|
+
} else {
|
|
4893
|
+
GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type));
|
|
4894
|
+
}
|
|
4953
4895
|
} break;
|
|
4954
4896
|
default:
|
|
4955
4897
|
{
|
|
@@ -5222,6 +5164,7 @@ static void ggml_compute_forward_soft_max_f32(
|
|
|
5222
5164
|
|
|
5223
5165
|
const ggml_tensor * src0 = dst->src[0];
|
|
5224
5166
|
const ggml_tensor * src1 = dst->src[1];
|
|
5167
|
+
const ggml_tensor * src2 = dst->src[2];
|
|
5225
5168
|
|
|
5226
5169
|
assert(ggml_is_contiguous(dst));
|
|
5227
5170
|
assert(ggml_are_same_shape(src0, dst));
|
|
@@ -5232,14 +5175,17 @@ static void ggml_compute_forward_soft_max_f32(
|
|
|
5232
5175
|
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
|
5233
5176
|
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
|
5234
5177
|
|
|
5235
|
-
// TODO: handle transposed/permuted matrices
|
|
5236
|
-
|
|
5237
5178
|
const int ith = params->ith;
|
|
5238
5179
|
const int nth = params->nth;
|
|
5239
5180
|
|
|
5240
5181
|
GGML_TENSOR_UNARY_OP_LOCALS
|
|
5241
5182
|
|
|
5242
|
-
|
|
5183
|
+
const int64_t nb11 = src1 ? src1->nb[1] : 1;
|
|
5184
|
+
const int64_t nb12 = src1 ? src1->nb[2] : 1;
|
|
5185
|
+
const int64_t nb13 = src1 ? src1->nb[3] : 1;
|
|
5186
|
+
|
|
5187
|
+
const int64_t ne12 = src1 ? src1->ne[2] : 1;
|
|
5188
|
+
const int64_t ne13 = src1 ? src1->ne[3] : 1;
|
|
5243
5189
|
|
|
5244
5190
|
// TODO: is this supposed to be ceil instead of floor?
|
|
5245
5191
|
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
|
|
@@ -5249,68 +5195,78 @@ static void ggml_compute_forward_soft_max_f32(
|
|
|
5249
5195
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
5250
5196
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
5251
5197
|
|
|
5252
|
-
|
|
5253
|
-
const int nr = ggml_nrows(src0);
|
|
5254
|
-
|
|
5255
|
-
// rows per thread
|
|
5256
|
-
const int dr = (nr + nth - 1)/nth;
|
|
5257
|
-
|
|
5258
|
-
// row range for this thread
|
|
5259
|
-
const int ir0 = dr*ith;
|
|
5260
|
-
const int ir1 = MIN(ir0 + dr, nr);
|
|
5261
|
-
|
|
5262
|
-
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
|
|
5198
|
+
float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
|
5263
5199
|
|
|
5264
5200
|
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
|
5265
5201
|
|
|
5266
|
-
|
|
5267
|
-
|
|
5268
|
-
const uint32_t h = (i1/ne01)%ne02; // head
|
|
5269
|
-
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
|
5270
|
-
|
|
5271
|
-
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
|
|
5272
|
-
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
|
5202
|
+
// sinks
|
|
5203
|
+
const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
|
|
5273
5204
|
|
|
5274
|
-
|
|
5275
|
-
|
|
5276
|
-
|
|
5277
|
-
|
|
5278
|
-
|
|
5279
|
-
|
|
5280
|
-
|
|
5281
|
-
|
|
5282
|
-
|
|
5283
|
-
|
|
5284
|
-
|
|
5285
|
-
|
|
5286
|
-
|
|
5287
|
-
|
|
5205
|
+
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
5206
|
+
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
5207
|
+
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
|
5208
|
+
const int64_t i11 = i01;
|
|
5209
|
+
const int64_t i12 = i02%ne12;
|
|
5210
|
+
const int64_t i13 = i03%ne13;
|
|
5211
|
+
|
|
5212
|
+
// ALiBi
|
|
5213
|
+
const uint32_t h = i02; // head
|
|
5214
|
+
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
|
5215
|
+
|
|
5216
|
+
float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
5217
|
+
float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
|
5218
|
+
|
|
5219
|
+
// broadcast the mask across rows
|
|
5220
|
+
ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
|
|
5221
|
+
float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
|
|
5222
|
+
|
|
5223
|
+
ggml_vec_cpy_f32 (ne00, wp, sp);
|
|
5224
|
+
ggml_vec_scale_f32(ne00, wp, scale);
|
|
5225
|
+
if (mp_f32) {
|
|
5226
|
+
if (use_f16) {
|
|
5227
|
+
for (int i = 0; i < ne00; ++i) {
|
|
5228
|
+
wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
|
|
5229
|
+
}
|
|
5230
|
+
} else {
|
|
5231
|
+
for (int i = 0; i < ne00; ++i) {
|
|
5232
|
+
wp[i] += slope*mp_f32[i];
|
|
5233
|
+
}
|
|
5234
|
+
}
|
|
5288
5235
|
}
|
|
5289
|
-
}
|
|
5290
|
-
}
|
|
5291
5236
|
|
|
5292
5237
|
#ifndef NDEBUG
|
|
5293
|
-
|
|
5294
|
-
|
|
5295
|
-
|
|
5296
|
-
|
|
5238
|
+
for (int i = 0; i < ne00; ++i) {
|
|
5239
|
+
//printf("p[%d] = %f\n", i, p[i]);
|
|
5240
|
+
assert(!isnan(wp[i]));
|
|
5241
|
+
}
|
|
5297
5242
|
#endif
|
|
5298
5243
|
|
|
5299
|
-
|
|
5300
|
-
|
|
5244
|
+
float max = -INFINITY;
|
|
5245
|
+
ggml_vec_max_f32(ne00, &max, wp);
|
|
5301
5246
|
|
|
5302
|
-
|
|
5303
|
-
|
|
5247
|
+
// if we have sinks, make a correction as if they were included in the softmax
|
|
5248
|
+
if (sk) {
|
|
5249
|
+
max = MAX(max, sk[i02]);
|
|
5250
|
+
}
|
|
5304
5251
|
|
|
5305
|
-
|
|
5306
|
-
|
|
5252
|
+
ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
|
|
5253
|
+
assert(sum > 0.0);
|
|
5254
|
+
|
|
5255
|
+
if (sk) {
|
|
5256
|
+
sum += (ggml_float) expf(sk[i02] - max);
|
|
5257
|
+
}
|
|
5258
|
+
|
|
5259
|
+
sum = 1.0/sum;
|
|
5260
|
+
ggml_vec_scale_f32(ne00, dp, sum);
|
|
5307
5261
|
|
|
5308
5262
|
#ifndef NDEBUG
|
|
5309
|
-
|
|
5310
|
-
|
|
5311
|
-
|
|
5312
|
-
|
|
5263
|
+
for (int i = 0; i < ne00; ++i) {
|
|
5264
|
+
assert(!isnan(dp[i]));
|
|
5265
|
+
assert(!isinf(dp[i]));
|
|
5266
|
+
}
|
|
5313
5267
|
#endif
|
|
5268
|
+
}
|
|
5269
|
+
}
|
|
5314
5270
|
}
|
|
5315
5271
|
}
|
|
5316
5272
|
|
|
@@ -5534,6 +5490,7 @@ void ggml_compute_forward_clamp(
|
|
|
5534
5490
|
case GGML_TYPE_Q5_1:
|
|
5535
5491
|
case GGML_TYPE_Q8_0:
|
|
5536
5492
|
case GGML_TYPE_Q8_1:
|
|
5493
|
+
case GGML_TYPE_MXFP4:
|
|
5537
5494
|
case GGML_TYPE_Q2_K:
|
|
5538
5495
|
case GGML_TYPE_Q3_K:
|
|
5539
5496
|
case GGML_TYPE_Q4_K:
|
|
@@ -5580,276 +5537,123 @@ static void rope_yarn(
|
|
|
5580
5537
|
float theta = theta_interp;
|
|
5581
5538
|
if (ext_factor != 0.0f) {
|
|
5582
5539
|
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
|
5583
|
-
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
|
5584
|
-
|
|
5585
|
-
// Get n-d magnitude scaling corrected for interpolation
|
|
5586
|
-
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
|
5587
|
-
}
|
|
5588
|
-
*cos_theta = cosf(theta) * mscale;
|
|
5589
|
-
*sin_theta = sinf(theta) * mscale;
|
|
5590
|
-
}
|
|
5591
|
-
|
|
5592
|
-
static void ggml_rope_cache_init(
|
|
5593
|
-
float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
|
|
5594
|
-
float * cache, float sin_sign, float theta_scale) {
|
|
5595
|
-
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
|
5596
|
-
float theta = theta_base;
|
|
5597
|
-
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
|
5598
|
-
const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
|
|
5599
|
-
rope_yarn(
|
|
5600
|
-
theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
|
|
5601
|
-
);
|
|
5602
|
-
cache[i0 + 1] *= sin_sign;
|
|
5603
|
-
|
|
5604
|
-
theta *= theta_scale;
|
|
5605
|
-
}
|
|
5606
|
-
}
|
|
5607
|
-
|
|
5608
|
-
static void ggml_mrope_cache_init(
|
|
5609
|
-
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
|
|
5610
|
-
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
|
|
5611
|
-
float * cache, float sin_sign, float theta_scale) {
|
|
5612
|
-
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
|
5613
|
-
float theta_t = theta_base_t;
|
|
5614
|
-
float theta_h = theta_base_h;
|
|
5615
|
-
float theta_w = theta_base_w;
|
|
5616
|
-
float theta_e = theta_base_e; // extra position id for vision encoder
|
|
5617
|
-
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
|
|
5618
|
-
int sec_w = sections[1] + sections[0];
|
|
5619
|
-
int sec_e = sections[2] + sec_w;
|
|
5620
|
-
GGML_ASSERT(sect_dims <= ne0);
|
|
5621
|
-
|
|
5622
|
-
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
|
5623
|
-
const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
|
|
5624
|
-
|
|
5625
|
-
int sector = (i0 / 2) % sect_dims;
|
|
5626
|
-
if (indep_sects) {
|
|
5627
|
-
// compute theta independently for each dim sections
|
|
5628
|
-
// (i.e. reset corresponding theta when `i0` go from one section to another)
|
|
5629
|
-
if (sector == 0) {
|
|
5630
|
-
theta_t = theta_base_t;
|
|
5631
|
-
}
|
|
5632
|
-
else if (sector == sections[0]) {
|
|
5633
|
-
theta_h = theta_base_h;;
|
|
5634
|
-
}
|
|
5635
|
-
else if (sector == sec_w) {
|
|
5636
|
-
theta_w = theta_base_w;
|
|
5637
|
-
}
|
|
5638
|
-
else if (sector == sec_e) {
|
|
5639
|
-
theta_e = theta_base_e;
|
|
5640
|
-
}
|
|
5641
|
-
}
|
|
5642
|
-
|
|
5643
|
-
float theta = theta_t;
|
|
5644
|
-
if (sector >= sections[0] && sector < sec_w) {
|
|
5645
|
-
theta = theta_h;
|
|
5646
|
-
}
|
|
5647
|
-
else if (sector >= sec_w && sector < sec_w + sections[2]) {
|
|
5648
|
-
theta = theta_w;
|
|
5649
|
-
}
|
|
5650
|
-
else if (sector >= sec_w + sections[2]) {
|
|
5651
|
-
theta = theta_e;
|
|
5652
|
-
}
|
|
5653
|
-
|
|
5654
|
-
rope_yarn(
|
|
5655
|
-
theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
|
|
5656
|
-
);
|
|
5657
|
-
cache[i0 + 1] *= sin_sign;
|
|
5658
|
-
|
|
5659
|
-
theta_t *= theta_scale;
|
|
5660
|
-
theta_w *= theta_scale;
|
|
5661
|
-
theta_h *= theta_scale;
|
|
5662
|
-
theta_e *= theta_scale;
|
|
5663
|
-
}
|
|
5664
|
-
}
|
|
5665
|
-
|
|
5666
|
-
static void ggml_compute_forward_rope_f32(
|
|
5667
|
-
const ggml_compute_params * params,
|
|
5668
|
-
ggml_tensor * dst,
|
|
5669
|
-
const bool forward) {
|
|
5670
|
-
|
|
5671
|
-
const ggml_tensor * src0 = dst->src[0];
|
|
5672
|
-
const ggml_tensor * src1 = dst->src[1];
|
|
5673
|
-
const ggml_tensor * src2 = dst->src[2];
|
|
5674
|
-
|
|
5675
|
-
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
|
5676
|
-
int sections[4];
|
|
5677
|
-
|
|
5678
|
-
//const int n_past = ((int32_t *) dst->op_params)[0];
|
|
5679
|
-
const int n_dims = ((int32_t *) dst->op_params)[1];
|
|
5680
|
-
const int mode = ((int32_t *) dst->op_params)[2];
|
|
5681
|
-
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
|
5682
|
-
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
|
5683
|
-
|
|
5684
|
-
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
|
5685
|
-
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
|
5686
|
-
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
|
5687
|
-
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
|
5688
|
-
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
|
5689
|
-
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
|
5690
|
-
memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
|
5691
|
-
|
|
5692
|
-
GGML_TENSOR_UNARY_OP_LOCALS
|
|
5693
|
-
|
|
5694
|
-
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
|
5695
|
-
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
|
5696
|
-
|
|
5697
|
-
GGML_ASSERT(nb00 == sizeof(float));
|
|
5698
|
-
|
|
5699
|
-
const int ith = params->ith;
|
|
5700
|
-
const int nth = params->nth;
|
|
5701
|
-
|
|
5702
|
-
const int nr = ggml_nrows(dst);
|
|
5703
|
-
|
|
5704
|
-
GGML_ASSERT(n_dims <= ne0);
|
|
5705
|
-
GGML_ASSERT(n_dims % 2 == 0);
|
|
5706
|
-
|
|
5707
|
-
// rows per thread
|
|
5708
|
-
const int dr = (nr + nth - 1)/nth;
|
|
5709
|
-
|
|
5710
|
-
// row range for this thread
|
|
5711
|
-
const int ir0 = dr*ith;
|
|
5712
|
-
const int ir1 = MIN(ir0 + dr, nr);
|
|
5713
|
-
|
|
5714
|
-
// row index used to determine which thread to use
|
|
5715
|
-
int ir = 0;
|
|
5716
|
-
|
|
5717
|
-
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
|
5718
|
-
|
|
5719
|
-
float corr_dims[2];
|
|
5720
|
-
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
|
5721
|
-
|
|
5722
|
-
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
|
5723
|
-
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
|
|
5724
|
-
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
|
5725
|
-
|
|
5726
|
-
if (is_mrope) {
|
|
5727
|
-
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
|
5728
|
-
}
|
|
5729
|
-
|
|
5730
|
-
if (is_vision) {
|
|
5731
|
-
GGML_ASSERT(n_dims == ne0/2);
|
|
5732
|
-
}
|
|
5733
|
-
|
|
5734
|
-
const float * freq_factors = NULL;
|
|
5735
|
-
if (src2 != NULL) {
|
|
5736
|
-
GGML_ASSERT(src2->type == GGML_TYPE_F32);
|
|
5737
|
-
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
|
|
5738
|
-
freq_factors = (const float *) src2->data;
|
|
5739
|
-
}
|
|
5740
|
-
|
|
5741
|
-
// backward process uses inverse rotation by cos and sin.
|
|
5742
|
-
// cos and sin build a rotation matrix, where the inverse is the transpose.
|
|
5743
|
-
// this essentially just switches the sign of sin.
|
|
5744
|
-
const float sin_sign = forward ? 1.0f : -1.0f;
|
|
5745
|
-
|
|
5746
|
-
const int32_t * pos = (const int32_t *) src1->data;
|
|
5747
|
-
|
|
5748
|
-
for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
|
|
5749
|
-
for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
|
5750
|
-
|
|
5751
|
-
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
|
5752
|
-
if (!is_mrope) {
|
|
5753
|
-
const int64_t p = pos[i2];
|
|
5754
|
-
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
5755
|
-
}
|
|
5756
|
-
else {
|
|
5757
|
-
const int64_t p_t = pos[i2];
|
|
5758
|
-
const int64_t p_h = pos[i2 + ne2];
|
|
5759
|
-
const int64_t p_w = pos[i2 + ne2 * 2];
|
|
5760
|
-
const int64_t p_e = pos[i2 + ne2 * 3];
|
|
5761
|
-
ggml_mrope_cache_init(
|
|
5762
|
-
p_t, p_h, p_w, p_e, sections, is_vision,
|
|
5763
|
-
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
5764
|
-
}
|
|
5765
|
-
|
|
5766
|
-
for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
|
|
5767
|
-
if (ir++ < ir0) continue;
|
|
5768
|
-
if (ir > ir1) break;
|
|
5769
|
-
|
|
5770
|
-
if (is_neox || is_mrope) {
|
|
5771
|
-
if (is_vision){
|
|
5772
|
-
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
5773
|
-
const int64_t ic = i0/2;
|
|
5774
|
-
|
|
5775
|
-
const float cos_theta = cache[i0 + 0];
|
|
5776
|
-
const float sin_theta = cache[i0 + 1];
|
|
5777
|
-
|
|
5778
|
-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
5779
|
-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
5540
|
+
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
|
5780
5541
|
|
|
5781
|
-
|
|
5782
|
-
|
|
5542
|
+
// Get n-d magnitude scaling corrected for interpolation
|
|
5543
|
+
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
|
5544
|
+
}
|
|
5545
|
+
*cos_theta = cosf(theta) * mscale;
|
|
5546
|
+
*sin_theta = sinf(theta) * mscale;
|
|
5547
|
+
}
|
|
5783
5548
|
|
|
5784
|
-
|
|
5785
|
-
|
|
5786
|
-
|
|
5787
|
-
|
|
5788
|
-
|
|
5789
|
-
|
|
5549
|
+
static void ggml_rope_cache_init(
|
|
5550
|
+
float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
|
|
5551
|
+
float * cache, float sin_sign, float theta_scale) {
|
|
5552
|
+
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
|
5553
|
+
float theta = theta_base;
|
|
5554
|
+
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
|
5555
|
+
const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
|
|
5556
|
+
rope_yarn(
|
|
5557
|
+
theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
|
|
5558
|
+
);
|
|
5559
|
+
cache[i0 + 1] *= sin_sign;
|
|
5790
5560
|
|
|
5791
|
-
|
|
5792
|
-
|
|
5561
|
+
theta *= theta_scale;
|
|
5562
|
+
}
|
|
5563
|
+
}
|
|
5793
5564
|
|
|
5794
|
-
|
|
5795
|
-
|
|
5565
|
+
static void ggml_mrope_cache_init(
|
|
5566
|
+
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
|
|
5567
|
+
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
|
|
5568
|
+
float * cache, float sin_sign, float theta_scale) {
|
|
5569
|
+
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
|
5570
|
+
float theta_t = theta_base_t;
|
|
5571
|
+
float theta_h = theta_base_h;
|
|
5572
|
+
float theta_w = theta_base_w;
|
|
5573
|
+
float theta_e = theta_base_e; // extra position id for vision encoder
|
|
5574
|
+
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
|
|
5575
|
+
int sec_w = sections[1] + sections[0];
|
|
5576
|
+
int sec_e = sections[2] + sec_w;
|
|
5577
|
+
GGML_ASSERT(sect_dims <= ne0);
|
|
5796
5578
|
|
|
5797
|
-
|
|
5798
|
-
|
|
5579
|
+
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
|
5580
|
+
const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
|
|
5799
5581
|
|
|
5800
|
-
|
|
5801
|
-
|
|
5802
|
-
|
|
5803
|
-
|
|
5804
|
-
|
|
5805
|
-
|
|
5806
|
-
|
|
5807
|
-
|
|
5582
|
+
int sector = (i0 / 2) % sect_dims;
|
|
5583
|
+
if (indep_sects) {
|
|
5584
|
+
// compute theta independently for each dim sections
|
|
5585
|
+
// (i.e. reset corresponding theta when `i0` go from one section to another)
|
|
5586
|
+
if (sector == 0) {
|
|
5587
|
+
theta_t = theta_base_t;
|
|
5588
|
+
}
|
|
5589
|
+
else if (sector == sections[0]) {
|
|
5590
|
+
theta_h = theta_base_h;;
|
|
5591
|
+
}
|
|
5592
|
+
else if (sector == sec_w) {
|
|
5593
|
+
theta_w = theta_base_w;
|
|
5594
|
+
}
|
|
5595
|
+
else if (sector == sec_e) {
|
|
5596
|
+
theta_e = theta_base_e;
|
|
5597
|
+
}
|
|
5598
|
+
}
|
|
5808
5599
|
|
|
5809
|
-
|
|
5810
|
-
|
|
5600
|
+
float theta = theta_t;
|
|
5601
|
+
if (is_imrope) { // qwen3vl apply interleaved mrope
|
|
5602
|
+
if (sector % 3 == 1 && sector < 3 * sections[1]) {
|
|
5603
|
+
theta = theta_h;
|
|
5604
|
+
} else if (sector % 3 == 2 && sector < 3 * sections[2]) {
|
|
5605
|
+
theta = theta_w;
|
|
5606
|
+
} else if (sector % 3 == 0 && sector < 3 * sections[0]) {
|
|
5607
|
+
theta = theta_t;
|
|
5608
|
+
} else {
|
|
5609
|
+
theta = theta_e;
|
|
5610
|
+
}
|
|
5611
|
+
} else {
|
|
5612
|
+
if (sector >= sections[0] && sector < sec_w) {
|
|
5613
|
+
theta = theta_h;
|
|
5614
|
+
}
|
|
5615
|
+
else if (sector >= sec_w && sector < sec_w + sections[2]) {
|
|
5616
|
+
theta = theta_w;
|
|
5617
|
+
}
|
|
5618
|
+
else if (sector >= sec_w + sections[2]) {
|
|
5619
|
+
theta = theta_e;
|
|
5620
|
+
}
|
|
5621
|
+
}
|
|
5811
5622
|
|
|
5812
|
-
|
|
5813
|
-
|
|
5623
|
+
rope_yarn(
|
|
5624
|
+
theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
|
|
5625
|
+
);
|
|
5626
|
+
cache[i0 + 1] *= sin_sign;
|
|
5814
5627
|
|
|
5815
|
-
|
|
5816
|
-
|
|
5817
|
-
|
|
5818
|
-
|
|
5628
|
+
theta_t *= theta_scale;
|
|
5629
|
+
theta_w *= theta_scale;
|
|
5630
|
+
theta_h *= theta_scale;
|
|
5631
|
+
theta_e *= theta_scale;
|
|
5632
|
+
}
|
|
5633
|
+
}
|
|
5819
5634
|
|
|
5820
|
-
if (is_vision) {
|
|
5821
|
-
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
|
5822
|
-
const int64_t ic = i0/2;
|
|
5823
5635
|
|
|
5824
|
-
|
|
5825
|
-
|
|
5636
|
+
template<typename T>
|
|
5637
|
+
static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
|
|
5638
|
+
for (int64_t i0 = 0; i0 < n; i0 += 2) {
|
|
5639
|
+
const int64_t ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
|
|
5826
5640
|
|
|
5827
|
-
|
|
5828
|
-
|
|
5641
|
+
const float cos_theta = cache[i0 + 0];
|
|
5642
|
+
const float sin_theta = cache[i0 + 1];
|
|
5829
5643
|
|
|
5830
|
-
|
|
5831
|
-
|
|
5644
|
+
const T * const src = src_data + ic;
|
|
5645
|
+
T * dst = dst_data + ic;
|
|
5832
5646
|
|
|
5833
|
-
|
|
5834
|
-
|
|
5835
|
-
}
|
|
5836
|
-
} else {
|
|
5837
|
-
// fill the remain channels with data from src tensor
|
|
5838
|
-
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
|
5839
|
-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
5840
|
-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
5647
|
+
const float x0 = type_conversion_table<T>::to_f32(src[0]);
|
|
5648
|
+
const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);
|
|
5841
5649
|
|
|
5842
|
-
|
|
5843
|
-
|
|
5844
|
-
|
|
5845
|
-
}
|
|
5846
|
-
}
|
|
5847
|
-
}
|
|
5848
|
-
}
|
|
5650
|
+
dst[0] = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
|
|
5651
|
+
dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
|
|
5652
|
+
}
|
|
5849
5653
|
}
|
|
5850
5654
|
|
|
5851
|
-
|
|
5852
|
-
static void
|
|
5655
|
+
template<typename T> //float or ggml_fp16_t
|
|
5656
|
+
static void ggml_compute_forward_rope_flt(
|
|
5853
5657
|
const ggml_compute_params * params,
|
|
5854
5658
|
ggml_tensor * dst,
|
|
5855
5659
|
const bool forward) {
|
|
@@ -5858,6 +5662,9 @@ static void ggml_compute_forward_rope_f16(
|
|
|
5858
5662
|
const ggml_tensor * src1 = dst->src[1];
|
|
5859
5663
|
const ggml_tensor * src2 = dst->src[2];
|
|
5860
5664
|
|
|
5665
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
5666
|
+
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
|
5667
|
+
|
|
5861
5668
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
|
5862
5669
|
int sections[4];
|
|
5863
5670
|
|
|
@@ -5866,6 +5673,7 @@ static void ggml_compute_forward_rope_f16(
|
|
|
5866
5673
|
const int mode = ((int32_t *) dst->op_params)[2];
|
|
5867
5674
|
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
|
5868
5675
|
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
|
5676
|
+
|
|
5869
5677
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
|
5870
5678
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
|
5871
5679
|
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
|
@@ -5874,13 +5682,13 @@ static void ggml_compute_forward_rope_f16(
|
|
|
5874
5682
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
|
5875
5683
|
memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
|
5876
5684
|
|
|
5877
|
-
|
|
5878
5685
|
GGML_TENSOR_UNARY_OP_LOCALS
|
|
5879
5686
|
|
|
5880
5687
|
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
|
5881
5688
|
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
|
5882
5689
|
|
|
5883
|
-
GGML_ASSERT(nb0 ==
|
|
5690
|
+
GGML_ASSERT(nb0 == nb00);
|
|
5691
|
+
GGML_ASSERT(nb0 == sizeof(T));
|
|
5884
5692
|
|
|
5885
5693
|
const int ith = params->ith;
|
|
5886
5694
|
const int nth = params->nth;
|
|
@@ -5905,11 +5713,11 @@ static void ggml_compute_forward_rope_f16(
|
|
|
5905
5713
|
float corr_dims[2];
|
|
5906
5714
|
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
|
5907
5715
|
|
|
5908
|
-
const bool
|
|
5909
|
-
const bool
|
|
5716
|
+
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
|
|
5717
|
+
const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
|
|
5910
5718
|
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
|
5911
5719
|
|
|
5912
|
-
if (
|
|
5720
|
+
if (mrope_used) {
|
|
5913
5721
|
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
|
5914
5722
|
}
|
|
5915
5723
|
|
|
@@ -5931,11 +5739,11 @@ static void ggml_compute_forward_rope_f16(
|
|
|
5931
5739
|
|
|
5932
5740
|
const int32_t * pos = (const int32_t *) src1->data;
|
|
5933
5741
|
|
|
5934
|
-
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
5935
|
-
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
|
5742
|
+
for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
|
|
5743
|
+
for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
|
5936
5744
|
|
|
5937
5745
|
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
|
5938
|
-
if (!
|
|
5746
|
+
if (!mrope_used) {
|
|
5939
5747
|
const int64_t p = pos[i2];
|
|
5940
5748
|
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
5941
5749
|
}
|
|
@@ -5945,90 +5753,44 @@ static void ggml_compute_forward_rope_f16(
|
|
|
5945
5753
|
const int64_t p_w = pos[i2 + ne2 * 2];
|
|
5946
5754
|
const int64_t p_e = pos[i2 + ne2 * 3];
|
|
5947
5755
|
ggml_mrope_cache_init(
|
|
5948
|
-
p_t, p_h, p_w, p_e, sections, is_vision,
|
|
5756
|
+
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
|
|
5949
5757
|
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
5950
5758
|
}
|
|
5951
5759
|
|
|
5952
|
-
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
5760
|
+
for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
|
|
5953
5761
|
if (ir++ < ir0) continue;
|
|
5954
5762
|
if (ir > ir1) break;
|
|
5955
5763
|
|
|
5956
|
-
|
|
5957
|
-
|
|
5958
|
-
|
|
5959
|
-
|
|
5960
|
-
|
|
5961
|
-
|
|
5962
|
-
|
|
5963
|
-
|
|
5964
|
-
|
|
5965
|
-
|
|
5966
|
-
|
|
5967
|
-
|
|
5968
|
-
|
|
5969
|
-
|
|
5970
|
-
|
|
5971
|
-
|
|
5972
|
-
|
|
5973
|
-
|
|
5974
|
-
|
|
5975
|
-
|
|
5976
|
-
|
|
5977
|
-
const float cos_theta = cache[i0 + 0];
|
|
5978
|
-
const float sin_theta = cache[i0 + 1];
|
|
5979
|
-
|
|
5980
|
-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
5981
|
-
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
5982
|
-
|
|
5983
|
-
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
|
5984
|
-
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
|
|
5985
|
-
|
|
5986
|
-
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
5987
|
-
dst_data[n_dims/2] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
5988
|
-
}
|
|
5989
|
-
}
|
|
5990
|
-
} else {
|
|
5991
|
-
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
5992
|
-
const float cos_theta = cache[i0 + 0];
|
|
5993
|
-
const float sin_theta = cache[i0 + 1];
|
|
5994
|
-
|
|
5995
|
-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
5996
|
-
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
5997
|
-
|
|
5998
|
-
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
|
5999
|
-
const float x1 = GGML_CPU_FP16_TO_FP32(src[1]);
|
|
6000
|
-
|
|
6001
|
-
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
6002
|
-
dst_data[1] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
6003
|
-
}
|
|
6004
|
-
}
|
|
6005
|
-
|
|
6006
|
-
if (is_vision) {
|
|
6007
|
-
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
|
6008
|
-
const int64_t ic = i0/2;
|
|
6009
|
-
|
|
6010
|
-
const float cos_theta = cache[i0 + 0];
|
|
6011
|
-
const float sin_theta = cache[i0 + 1];
|
|
6012
|
-
|
|
6013
|
-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
6014
|
-
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
6015
|
-
|
|
6016
|
-
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
|
6017
|
-
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
|
|
6018
|
-
|
|
6019
|
-
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
6020
|
-
dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
6021
|
-
}
|
|
6022
|
-
} else {
|
|
5764
|
+
T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
|
|
5765
|
+
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
|
|
5766
|
+
|
|
5767
|
+
switch (mode) {
|
|
5768
|
+
case GGML_ROPE_TYPE_NORMAL:
|
|
5769
|
+
rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
|
|
5770
|
+
break;
|
|
5771
|
+
case GGML_ROPE_TYPE_NEOX:
|
|
5772
|
+
case GGML_ROPE_TYPE_MROPE:
|
|
5773
|
+
case GGML_ROPE_TYPE_IMROPE:
|
|
5774
|
+
rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
|
|
5775
|
+
break;
|
|
5776
|
+
case GGML_ROPE_TYPE_VISION:
|
|
5777
|
+
rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
|
|
5778
|
+
break;
|
|
5779
|
+
default:
|
|
5780
|
+
GGML_ABORT("rope type not supported");
|
|
5781
|
+
}
|
|
5782
|
+
|
|
5783
|
+
if (!is_vision) {
|
|
5784
|
+
// fill the remain channels with data from src tensor
|
|
6023
5785
|
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
|
6024
|
-
const
|
|
6025
|
-
|
|
5786
|
+
const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
5787
|
+
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
6026
5788
|
|
|
6027
5789
|
dst_data[0] = src[0];
|
|
6028
5790
|
dst_data[1] = src[1];
|
|
6029
5791
|
}
|
|
6030
5792
|
}
|
|
6031
|
-
}
|
|
5793
|
+
} //attn-heads
|
|
6032
5794
|
}
|
|
6033
5795
|
}
|
|
6034
5796
|
}
|
|
@@ -6042,11 +5804,11 @@ void ggml_compute_forward_rope(
|
|
|
6042
5804
|
switch (src0->type) {
|
|
6043
5805
|
case GGML_TYPE_F16:
|
|
6044
5806
|
{
|
|
6045
|
-
|
|
5807
|
+
ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);
|
|
6046
5808
|
} break;
|
|
6047
5809
|
case GGML_TYPE_F32:
|
|
6048
5810
|
{
|
|
6049
|
-
|
|
5811
|
+
ggml_compute_forward_rope_flt<float>(params, dst, true);
|
|
6050
5812
|
} break;
|
|
6051
5813
|
default:
|
|
6052
5814
|
{
|
|
@@ -6066,11 +5828,11 @@ void ggml_compute_forward_rope_back(
|
|
|
6066
5828
|
switch (src0->type) {
|
|
6067
5829
|
case GGML_TYPE_F16:
|
|
6068
5830
|
{
|
|
6069
|
-
|
|
5831
|
+
ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);
|
|
6070
5832
|
} break;
|
|
6071
5833
|
case GGML_TYPE_F32:
|
|
6072
5834
|
{
|
|
6073
|
-
|
|
5835
|
+
ggml_compute_forward_rope_flt<float>(params, dst, false);
|
|
6074
5836
|
} break;
|
|
6075
5837
|
default:
|
|
6076
5838
|
{
|
|
@@ -6477,68 +6239,251 @@ void ggml_compute_forward_im2col_back_f32(
|
|
|
6477
6239
|
const int ith = params->ith;
|
|
6478
6240
|
const int nth = params->nth;
|
|
6479
6241
|
|
|
6480
|
-
const int64_t N = is_2D ? ne3 : ne2;
|
|
6481
|
-
const int64_t IC = is_2D ? ne2 : ne1;
|
|
6482
|
-
const int64_t IH = is_2D ? ne1 : 1;
|
|
6483
|
-
const int64_t IW = ne0;
|
|
6242
|
+
const int64_t N = is_2D ? ne3 : ne2;
|
|
6243
|
+
const int64_t IC = is_2D ? ne2 : ne1;
|
|
6244
|
+
const int64_t IH = is_2D ? ne1 : 1;
|
|
6245
|
+
const int64_t IW = ne0;
|
|
6246
|
+
|
|
6247
|
+
const int64_t KH = is_2D ? ne11 : 1;
|
|
6248
|
+
const int64_t KW = ne10;
|
|
6249
|
+
|
|
6250
|
+
const int64_t OH = is_2D ? ne02 : 1;
|
|
6251
|
+
const int64_t OW = ne01;
|
|
6252
|
+
|
|
6253
|
+
int ofs0 = is_2D ? nb3 : nb2;
|
|
6254
|
+
int ofs1 = is_2D ? nb2 : nb1;
|
|
6255
|
+
|
|
6256
|
+
GGML_ASSERT(nb0 == sizeof(float));
|
|
6257
|
+
|
|
6258
|
+
// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
|
|
6259
|
+
{
|
|
6260
|
+
float * const wdata = (float *) dst->data;
|
|
6261
|
+
|
|
6262
|
+
for (int64_t in = 0; in < N; in++) {
|
|
6263
|
+
for (int64_t iic = ith; iic < IC; iic += nth) {
|
|
6264
|
+
for (int64_t iih = 0; iih < IH; iih++) {
|
|
6265
|
+
for (int64_t iiw = 0; iiw < IW; iiw++) {
|
|
6266
|
+
|
|
6267
|
+
// micro kernel
|
|
6268
|
+
float grad = 0.0f;
|
|
6269
|
+
for (int64_t ikh = 0; ikh < KH; ikh++) {
|
|
6270
|
+
for (int64_t ikw = 0; ikw < KW; ikw++) {
|
|
6271
|
+
// For s0 > 1 some values were skipped over in the forward pass.
|
|
6272
|
+
// These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
|
|
6273
|
+
const int64_t tmpw = (iiw + p0 - ikw*d0);
|
|
6274
|
+
if (tmpw % s0 != 0) {
|
|
6275
|
+
continue;
|
|
6276
|
+
}
|
|
6277
|
+
const int64_t iow = tmpw / s0;
|
|
6278
|
+
|
|
6279
|
+
// Equivalent logic as above except for s1.
|
|
6280
|
+
int64_t ioh;
|
|
6281
|
+
if (is_2D) {
|
|
6282
|
+
const int64_t tmph = iih + p1 - ikh*d1;
|
|
6283
|
+
|
|
6284
|
+
if (tmph % s1 != 0) {
|
|
6285
|
+
continue;
|
|
6286
|
+
}
|
|
6287
|
+
|
|
6288
|
+
ioh = tmph / s1;
|
|
6289
|
+
} else {
|
|
6290
|
+
ioh = 0;
|
|
6291
|
+
}
|
|
6292
|
+
|
|
6293
|
+
if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
|
|
6294
|
+
continue;
|
|
6295
|
+
}
|
|
6296
|
+
|
|
6297
|
+
const float * const grad_in = (const float *) src0->data
|
|
6298
|
+
+ (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
|
|
6299
|
+
grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
|
|
6300
|
+
}
|
|
6301
|
+
}
|
|
6302
|
+
float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
|
|
6303
|
+
dst_data[iih*IW + iiw] = grad;
|
|
6304
|
+
}
|
|
6305
|
+
}
|
|
6306
|
+
}
|
|
6307
|
+
}
|
|
6308
|
+
}
|
|
6309
|
+
}
|
|
6310
|
+
|
|
6311
|
+
|
|
6312
|
+
// ggml_compute_forward_im2col_3d_f16
|
|
6313
|
+
// src0: kernel [OC*IC, KD, KH, KW]
|
|
6314
|
+
// src1: image [N*IC, ID, IH, IW]
|
|
6315
|
+
// dst: result [N*OD, OH, OW, IC * KD * KH * KW]
|
|
6316
|
+
static void ggml_compute_forward_im2col_3d_f16(
|
|
6317
|
+
const ggml_compute_params * params,
|
|
6318
|
+
ggml_tensor * dst) {
|
|
6319
|
+
|
|
6320
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
6321
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
6322
|
+
|
|
6323
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
|
6324
|
+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
6325
|
+
GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
|
6326
|
+
|
|
6327
|
+
GGML_TENSOR_BINARY_OP_LOCALS;
|
|
6328
|
+
|
|
6329
|
+
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
|
6330
|
+
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
|
6331
|
+
const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
|
|
6332
|
+
const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
|
|
6333
|
+
const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
|
|
6334
|
+
const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
|
|
6335
|
+
const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
|
|
6336
|
+
const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
|
|
6337
|
+
const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
|
|
6338
|
+
const int32_t IC = ((const int32_t *)(dst->op_params))[9];
|
|
6339
|
+
|
|
6340
|
+
|
|
6341
|
+
const int ith = params->ith;
|
|
6342
|
+
const int nth = params->nth;
|
|
6343
|
+
|
|
6344
|
+
const int64_t N = ne13 / IC;
|
|
6345
|
+
const int64_t ID = ne12;
|
|
6346
|
+
const int64_t IH = ne11;
|
|
6347
|
+
const int64_t IW = ne10;
|
|
6348
|
+
|
|
6349
|
+
const int64_t OC = ne03 / IC;
|
|
6350
|
+
GGML_UNUSED(OC);
|
|
6351
|
+
const int64_t KD = ne02;
|
|
6352
|
+
const int64_t KH = ne01;
|
|
6353
|
+
const int64_t KW = ne00;
|
|
6354
|
+
|
|
6355
|
+
const int64_t OD = ne3 / N;
|
|
6356
|
+
const int64_t OH = ne2;
|
|
6357
|
+
const int64_t OW = ne1;
|
|
6358
|
+
const int64_t OH_OW = OH*OW;
|
|
6359
|
+
const int64_t KD_KH_KW = KD*KH*KW;
|
|
6360
|
+
const int64_t KH_KW = KH*KW;
|
|
6361
|
+
const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
|
|
6362
|
+
|
|
6363
|
+
GGML_ASSERT(nb10 == sizeof(float));
|
|
6364
|
+
|
|
6365
|
+
// im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
|
|
6366
|
+
{
|
|
6367
|
+
ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
|
|
6368
|
+
|
|
6369
|
+
for (int64_t in = 0; in < N; in++) {
|
|
6370
|
+
for (int64_t iod = 0; iod < OD; iod++) {
|
|
6371
|
+
for (int64_t ioh = 0; ioh < OH; ioh++) {
|
|
6372
|
+
for (int64_t iow = 0; iow < OW; iow++) {
|
|
6373
|
+
for (int64_t iic = ith; iic < IC; iic += nth) {
|
|
6374
|
+
|
|
6375
|
+
// micro kernel
|
|
6376
|
+
ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
|
|
6377
|
+
const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
|
|
6378
|
+
|
|
6379
|
+
for (int64_t ikd = 0; ikd < KD; ikd++) {
|
|
6380
|
+
for (int64_t ikh = 0; ikh < KH; ikh++) {
|
|
6381
|
+
for (int64_t ikw = 0; ikw < KW; ikw++) {
|
|
6382
|
+
const int64_t iiw = iow*s0 + ikw*d0 - p0;
|
|
6383
|
+
const int64_t iih = ioh*s1 + ikh*d1 - p1;
|
|
6384
|
+
const int64_t iid = iod*s2 + ikd*d2 - p2;
|
|
6385
|
+
|
|
6386
|
+
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
|
6387
|
+
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
|
|
6388
|
+
} else {
|
|
6389
|
+
const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
|
|
6390
|
+
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s);
|
|
6391
|
+
}
|
|
6392
|
+
}
|
|
6393
|
+
}
|
|
6394
|
+
}
|
|
6395
|
+
}
|
|
6396
|
+
}
|
|
6397
|
+
}
|
|
6398
|
+
}
|
|
6399
|
+
}
|
|
6400
|
+
}
|
|
6401
|
+
}
|
|
6402
|
+
|
|
6403
|
+
// ggml_compute_forward_im2col_3d_f32
|
|
6404
|
+
// src0: kernel [OC*IC, KD, KH, KW]
|
|
6405
|
+
// src1: image [N*IC, ID, IH, IW]
|
|
6406
|
+
// dst: result [N*OD, OH, OW, IC * KD * KH * KW]
|
|
6407
|
+
static void ggml_compute_forward_im2col_3d_f32(
|
|
6408
|
+
const ggml_compute_params * params,
|
|
6409
|
+
ggml_tensor * dst) {
|
|
6410
|
+
|
|
6411
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
6412
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
6413
|
+
|
|
6414
|
+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
6415
|
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
6416
|
+
|
|
6417
|
+
GGML_TENSOR_BINARY_OP_LOCALS;
|
|
6418
|
+
|
|
6419
|
+
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
|
6420
|
+
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
|
6421
|
+
const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
|
|
6422
|
+
const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
|
|
6423
|
+
const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
|
|
6424
|
+
const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
|
|
6425
|
+
const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
|
|
6426
|
+
const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
|
|
6427
|
+
const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
|
|
6428
|
+
const int32_t IC = ((const int32_t *)(dst->op_params))[9];
|
|
6429
|
+
|
|
6430
|
+
|
|
6431
|
+
const int ith = params->ith;
|
|
6432
|
+
const int nth = params->nth;
|
|
6433
|
+
|
|
6434
|
+
const int64_t N = ne13 / IC;
|
|
6435
|
+
const int64_t ID = ne12;
|
|
6436
|
+
const int64_t IH = ne11;
|
|
6437
|
+
const int64_t IW = ne10;
|
|
6484
6438
|
|
|
6485
|
-
const int64_t
|
|
6486
|
-
|
|
6439
|
+
const int64_t OC = ne03 / IC;
|
|
6440
|
+
GGML_UNUSED(OC);
|
|
6441
|
+
const int64_t KD = ne02;
|
|
6442
|
+
const int64_t KH = ne01;
|
|
6443
|
+
const int64_t KW = ne00;
|
|
6487
6444
|
|
|
6488
|
-
const int64_t
|
|
6489
|
-
const int64_t
|
|
6445
|
+
const int64_t OD = ne3 / N;
|
|
6446
|
+
const int64_t OH = ne2;
|
|
6447
|
+
const int64_t OW = ne1;
|
|
6490
6448
|
|
|
6491
|
-
|
|
6492
|
-
|
|
6449
|
+
const int64_t OH_OW = OH*OW;
|
|
6450
|
+
const int64_t KD_KH_KW = KD*KH*KW;
|
|
6451
|
+
const int64_t KH_KW = KH*KW;
|
|
6452
|
+
const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
|
|
6493
6453
|
|
|
6494
|
-
GGML_ASSERT(
|
|
6454
|
+
GGML_ASSERT(nb10 == sizeof(float));
|
|
6495
6455
|
|
|
6496
|
-
// im2col: [N,
|
|
6456
|
+
// im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
|
|
6497
6457
|
{
|
|
6498
6458
|
float * const wdata = (float *) dst->data;
|
|
6499
6459
|
|
|
6500
6460
|
for (int64_t in = 0; in < N; in++) {
|
|
6501
|
-
for (int64_t
|
|
6502
|
-
for (int64_t
|
|
6503
|
-
for (int64_t
|
|
6504
|
-
|
|
6505
|
-
|
|
6506
|
-
|
|
6507
|
-
|
|
6508
|
-
|
|
6509
|
-
|
|
6510
|
-
|
|
6511
|
-
|
|
6512
|
-
|
|
6513
|
-
|
|
6514
|
-
|
|
6515
|
-
|
|
6516
|
-
|
|
6517
|
-
|
|
6518
|
-
|
|
6519
|
-
|
|
6520
|
-
|
|
6521
|
-
|
|
6522
|
-
|
|
6523
|
-
continue;
|
|
6461
|
+
for (int64_t iod = 0; iod < OD; iod++) {
|
|
6462
|
+
for (int64_t ioh = 0; ioh < OH; ioh++) {
|
|
6463
|
+
for (int64_t iow = 0; iow < OW; iow++) {
|
|
6464
|
+
for (int64_t iic = ith; iic < IC; iic += nth) {
|
|
6465
|
+
|
|
6466
|
+
// micro kernel
|
|
6467
|
+
float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
|
|
6468
|
+
const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
|
|
6469
|
+
|
|
6470
|
+
for (int64_t ikd = 0; ikd < KD; ikd++) {
|
|
6471
|
+
for (int64_t ikh = 0; ikh < KH; ikh++) {
|
|
6472
|
+
for (int64_t ikw = 0; ikw < KW; ikw++) {
|
|
6473
|
+
const int64_t iiw = iow*s0 + ikw*d0 - p0;
|
|
6474
|
+
const int64_t iih = ioh*s1 + ikh*d1 - p1;
|
|
6475
|
+
const int64_t iid = iod*s2 + ikd*d2 - p2;
|
|
6476
|
+
|
|
6477
|
+
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
|
|
6478
|
+
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
|
|
6479
|
+
} else {
|
|
6480
|
+
const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
|
|
6481
|
+
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s;
|
|
6482
|
+
}
|
|
6524
6483
|
}
|
|
6525
|
-
|
|
6526
|
-
ioh = tmph / s1;
|
|
6527
|
-
} else {
|
|
6528
|
-
ioh = 0;
|
|
6529
|
-
}
|
|
6530
|
-
|
|
6531
|
-
if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
|
|
6532
|
-
continue;
|
|
6533
6484
|
}
|
|
6534
|
-
|
|
6535
|
-
const float * const grad_in = (const float *) src0->data
|
|
6536
|
-
+ (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
|
|
6537
|
-
grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
|
|
6538
6485
|
}
|
|
6539
6486
|
}
|
|
6540
|
-
float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
|
|
6541
|
-
dst_data[iih*IW + iiw] = grad;
|
|
6542
6487
|
}
|
|
6543
6488
|
}
|
|
6544
6489
|
}
|
|
@@ -6546,6 +6491,26 @@ void ggml_compute_forward_im2col_back_f32(
|
|
|
6546
6491
|
}
|
|
6547
6492
|
}
|
|
6548
6493
|
|
|
6494
|
+
|
|
6495
|
+
void ggml_compute_forward_im2col_3d(
|
|
6496
|
+
const ggml_compute_params * params,
|
|
6497
|
+
ggml_tensor * dst) {
|
|
6498
|
+
switch (dst->type) {
|
|
6499
|
+
case GGML_TYPE_F16:
|
|
6500
|
+
{
|
|
6501
|
+
ggml_compute_forward_im2col_3d_f16(params, dst);
|
|
6502
|
+
} break;
|
|
6503
|
+
case GGML_TYPE_F32:
|
|
6504
|
+
{
|
|
6505
|
+
ggml_compute_forward_im2col_3d_f32(params, dst);
|
|
6506
|
+
} break;
|
|
6507
|
+
default:
|
|
6508
|
+
{
|
|
6509
|
+
GGML_ABORT("fatal error");
|
|
6510
|
+
}
|
|
6511
|
+
}
|
|
6512
|
+
}
|
|
6513
|
+
|
|
6549
6514
|
static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
|
|
6550
6515
|
void * a, void * b, float * c) {
|
|
6551
6516
|
const ggml_type_traits * traits = ggml_get_type_traits(type);
|
|
@@ -6589,8 +6554,13 @@ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params
|
|
|
6589
6554
|
ggml_compute_forward_mul_mat(params, &dst);
|
|
6590
6555
|
}
|
|
6591
6556
|
|
|
6557
|
+
static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
|
|
6558
|
+
return (coord + size) % size; // adding size avoids negative number weirdness
|
|
6559
|
+
}
|
|
6560
|
+
|
|
6592
6561
|
// ggml_compute_forward_conv_2d
|
|
6593
6562
|
|
|
6563
|
+
|
|
6594
6564
|
static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
|
|
6595
6565
|
const ggml_tensor * kernel, // [KW, KH, IC, OC]
|
|
6596
6566
|
const ggml_tensor * src, // [W, H, C, N]
|
|
@@ -6726,6 +6696,148 @@ void ggml_compute_forward_conv_2d(
|
|
|
6726
6696
|
ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
|
|
6727
6697
|
}
|
|
6728
6698
|
|
|
6699
|
+
// ggml_compute_forward_conv_3d
|
|
6700
|
+
|
|
6701
|
+
static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
|
|
6702
|
+
const ggml_tensor * kernel,
|
|
6703
|
+
const ggml_tensor * src,
|
|
6704
|
+
ggml_tensor * dst,
|
|
6705
|
+
ggml_type kernel_type) {
|
|
6706
|
+
|
|
6707
|
+
GGML_ASSERT(ggml_is_contiguous(kernel));
|
|
6708
|
+
GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
|
|
6709
|
+
GGML_ASSERT(kernel->type == kernel_type);
|
|
6710
|
+
|
|
6711
|
+
const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
|
|
6712
|
+
|
|
6713
|
+
const int32_t s0 = dst->op_params[0];
|
|
6714
|
+
const int32_t s1 = dst->op_params[1];
|
|
6715
|
+
const int32_t s2 = dst->op_params[2];
|
|
6716
|
+
const int32_t p0 = dst->op_params[3];
|
|
6717
|
+
const int32_t p1 = dst->op_params[4];
|
|
6718
|
+
const int32_t p2 = dst->op_params[5];
|
|
6719
|
+
const int32_t d0 = dst->op_params[6];
|
|
6720
|
+
const int32_t d1 = dst->op_params[7];
|
|
6721
|
+
const int32_t d2 = dst->op_params[8];
|
|
6722
|
+
const int32_t c = dst->op_params[9];
|
|
6723
|
+
const int32_t n = dst->op_params[10];
|
|
6724
|
+
const int32_t oc = dst->op_params[11];
|
|
6725
|
+
|
|
6726
|
+
const int64_t src_w = src->ne[0];
|
|
6727
|
+
const int64_t src_h = src->ne[1];
|
|
6728
|
+
const int64_t src_d = src->ne[2];
|
|
6729
|
+
const int64_t knl_w = kernel->ne[0];
|
|
6730
|
+
const int64_t knl_h = kernel->ne[1];
|
|
6731
|
+
const int64_t knl_d = kernel->ne[2];
|
|
6732
|
+
const int64_t dst_w = dst->ne[0];
|
|
6733
|
+
const int64_t dst_h = dst->ne[1];
|
|
6734
|
+
const int64_t dst_d = dst->ne[2];
|
|
6735
|
+
|
|
6736
|
+
const float * src_data = (float *) src->data;
|
|
6737
|
+
void * knl_data = kernel->data;
|
|
6738
|
+
float * dst_data = (float *) dst->data;
|
|
6739
|
+
|
|
6740
|
+
const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
|
|
6741
|
+
const int64_t knl_n_total = knl_n_per_channel * c;
|
|
6742
|
+
const int64_t patch_total = n * dst_w * dst_h * dst_d;
|
|
6743
|
+
|
|
6744
|
+
const int64_t space_per_patch = knl_n_total * traits->type_size + oc * sizeof(float);
|
|
6745
|
+
const int64_t batch_size = params->wsize / space_per_patch;
|
|
6746
|
+
const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
|
|
6747
|
+
const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
|
|
6748
|
+
|
|
6749
|
+
GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
|
|
6750
|
+
|
|
6751
|
+
void * tmp = params->wdata;
|
|
6752
|
+
|
|
6753
|
+
for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
|
|
6754
|
+
const int64_t patch_start_batch = batch_i * patches_per_batch;
|
|
6755
|
+
const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch, patch_total);
|
|
6756
|
+
const int64_t patch_n_in_batch = patch_end_batch - patch_start_batch;
|
|
6757
|
+
|
|
6758
|
+
const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
|
|
6759
|
+
const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
|
|
6760
|
+
const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
|
|
6761
|
+
|
|
6762
|
+
for (int64_t p = patch_start; p < patch_end; ++p) {
|
|
6763
|
+
const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
|
|
6764
|
+
const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
|
|
6765
|
+
const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
|
|
6766
|
+
const int64_t dst_z = p_in_batch / (dst_w * dst_h);
|
|
6767
|
+
const int64_t dst_y = p_in_depth / dst_w;
|
|
6768
|
+
const int64_t dst_x = p_in_depth % dst_w;
|
|
6769
|
+
|
|
6770
|
+
char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
|
|
6771
|
+
|
|
6772
|
+
for (int64_t ic = 0; ic < c; ++ic) {
|
|
6773
|
+
for (int64_t kz = 0; kz < knl_d; ++kz) {
|
|
6774
|
+
for (int64_t ky = 0; ky < knl_h; ++ky) {
|
|
6775
|
+
for (int64_t kx = 0; kx < knl_w; ++kx) {
|
|
6776
|
+
const int64_t sz = dst_z * s2 + kz * d2 - p2;
|
|
6777
|
+
const int64_t sy = dst_y * s1 + ky * d1 - p1;
|
|
6778
|
+
const int64_t sx = dst_x * s0 + kx * d0 - p0;
|
|
6779
|
+
|
|
6780
|
+
int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
|
|
6781
|
+
|
|
6782
|
+
float src_val;
|
|
6783
|
+
if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
|
|
6784
|
+
src_val = 0.0f;
|
|
6785
|
+
} else {
|
|
6786
|
+
const int64_t cn_idx = batch_idx * c + ic;
|
|
6787
|
+
const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
|
|
6788
|
+
src_val = *src_ptr;
|
|
6789
|
+
}
|
|
6790
|
+
|
|
6791
|
+
char * element_ptr = dst_row + dst_idx * traits->type_size;
|
|
6792
|
+
if (kernel_type == GGML_TYPE_F32) {
|
|
6793
|
+
*(float *)element_ptr = src_val;
|
|
6794
|
+
} else if (kernel_type == GGML_TYPE_F16) {
|
|
6795
|
+
*(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
|
|
6796
|
+
}
|
|
6797
|
+
}
|
|
6798
|
+
}
|
|
6799
|
+
}
|
|
6800
|
+
}
|
|
6801
|
+
}
|
|
6802
|
+
|
|
6803
|
+
ggml_barrier(params->threadpool);
|
|
6804
|
+
|
|
6805
|
+
float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
|
|
6806
|
+
ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
|
|
6807
|
+
|
|
6808
|
+
ggml_barrier(params->threadpool);
|
|
6809
|
+
|
|
6810
|
+
const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
|
|
6811
|
+
const int64_t permute_start = params->ith * permute_per_thread;
|
|
6812
|
+
const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);
|
|
6813
|
+
|
|
6814
|
+
for (int64_t i = permute_start; i < permute_end; ++i) {
|
|
6815
|
+
const int64_t p = patch_start_batch + i;
|
|
6816
|
+
const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
|
|
6817
|
+
const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
|
|
6818
|
+
const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
|
|
6819
|
+
const int64_t dst_z = p_in_batch / (dst_w * dst_h);
|
|
6820
|
+
const int64_t dst_y = p_in_depth / dst_w;
|
|
6821
|
+
const int64_t dst_x = p_in_depth % dst_w;
|
|
6822
|
+
|
|
6823
|
+
for (int64_t ioc = 0; ioc < oc; ++ioc) {
|
|
6824
|
+
const float value = gemm_output[i * oc + ioc];
|
|
6825
|
+
const int64_t ocn_idx = batch_idx * oc + ioc;
|
|
6826
|
+
float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
|
|
6827
|
+
*dst_ptr = value;
|
|
6828
|
+
}
|
|
6829
|
+
}
|
|
6830
|
+
}
|
|
6831
|
+
}
|
|
6832
|
+
|
|
6833
|
+
void ggml_compute_forward_conv_3d(
|
|
6834
|
+
const ggml_compute_params * params,
|
|
6835
|
+
ggml_tensor * dst) {
|
|
6836
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
6837
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
6838
|
+
ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
|
|
6839
|
+
}
|
|
6840
|
+
|
|
6729
6841
|
// ggml_compute_forward_conv_transpose_2d
|
|
6730
6842
|
|
|
6731
6843
|
void ggml_compute_forward_conv_transpose_2d(
|
|
@@ -6857,7 +6969,11 @@ static void ggml_compute_forward_conv_2d_dw_cwhn(
|
|
|
6857
6969
|
const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
|
|
6858
6970
|
|
|
6859
6971
|
#ifdef GGML_SIMD
|
|
6860
|
-
|
|
6972
|
+
#if defined(__ARM_FEATURE_SVE)
|
|
6973
|
+
const int64_t pkg_size = svcntw();
|
|
6974
|
+
#else
|
|
6975
|
+
const int64_t pkg_size = GGML_F32_EPR;
|
|
6976
|
+
#endif
|
|
6861
6977
|
const int64_t pkg_count = c / pkg_size;
|
|
6862
6978
|
const int64_t c_pkg_end = pkg_count * pkg_size;
|
|
6863
6979
|
#else
|
|
@@ -7280,10 +7396,17 @@ static void ggml_compute_forward_upscale_f32(
|
|
|
7280
7396
|
float sf1 = (float)ne1/src0->ne[1];
|
|
7281
7397
|
float sf2 = (float)ne2/src0->ne[2];
|
|
7282
7398
|
float sf3 = (float)ne3/src0->ne[3];
|
|
7399
|
+
float pixel_offset = 0.5f;
|
|
7283
7400
|
|
|
7284
7401
|
const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
|
|
7285
7402
|
const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
|
|
7286
7403
|
|
|
7404
|
+
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
|
7405
|
+
pixel_offset = 0.0f;
|
|
7406
|
+
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
|
|
7407
|
+
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
|
|
7408
|
+
}
|
|
7409
|
+
|
|
7287
7410
|
if (mode == GGML_SCALE_MODE_NEAREST) {
|
|
7288
7411
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
7289
7412
|
const int64_t i03 = i3 / sf3;
|
|
@@ -7302,14 +7425,66 @@ static void ggml_compute_forward_upscale_f32(
|
|
|
7302
7425
|
}
|
|
7303
7426
|
}
|
|
7304
7427
|
}
|
|
7305
|
-
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
|
|
7306
|
-
|
|
7307
|
-
|
|
7308
|
-
|
|
7309
|
-
|
|
7310
|
-
|
|
7311
|
-
|
|
7428
|
+
} else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) {
|
|
7429
|
+
// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
|
|
7430
|
+
// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
|
|
7431
|
+
auto triangle_filter = [](float x) -> float {
|
|
7432
|
+
return std::max(1.0f - fabsf(x), 0.0f);
|
|
7433
|
+
};
|
|
7434
|
+
|
|
7435
|
+
// support and invscale, minimum 1 pixel for bilinear
|
|
7436
|
+
const float support1 = std::max(1.0f, 1.0f / sf1);
|
|
7437
|
+
const float invscale1 = 1.0f / support1;
|
|
7438
|
+
const float support0 = std::max(1.0f, 1.0f / sf0);
|
|
7439
|
+
const float invscale0 = 1.0f / support0;
|
|
7440
|
+
|
|
7441
|
+
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
7442
|
+
const int64_t i03 = i3 / sf3;
|
|
7443
|
+
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
|
|
7444
|
+
const int64_t i02 = i2 / sf2;
|
|
7445
|
+
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
7446
|
+
const float y = ((float) i1 + pixel_offset) / sf1;
|
|
7447
|
+
for (int64_t i0 = 0; i0 < ne0; i0++) {
|
|
7448
|
+
const float x = ((float) i0 + pixel_offset) / sf0;
|
|
7449
|
+
|
|
7450
|
+
// the range of source pixels that contribute
|
|
7451
|
+
const int64_t x_min = std::max<int64_t>(x - support0 + pixel_offset, 0);
|
|
7452
|
+
const int64_t x_max = std::min<int64_t>(x + support0 + pixel_offset, ne00);
|
|
7453
|
+
const int64_t y_min = std::max<int64_t>(y - support1 + pixel_offset, 0);
|
|
7454
|
+
const int64_t y_max = std::min<int64_t>(y + support1 + pixel_offset, ne01);
|
|
7455
|
+
|
|
7456
|
+
// bilinear filter with antialiasing
|
|
7457
|
+
float val = 0.0f;
|
|
7458
|
+
float total_weight = 0.0f;
|
|
7459
|
+
|
|
7460
|
+
for (int64_t sy = y_min; sy < y_max; sy++) {
|
|
7461
|
+
const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
|
|
7462
|
+
|
|
7463
|
+
for (int64_t sx = x_min; sx < x_max; sx++) {
|
|
7464
|
+
const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
|
|
7465
|
+
const float weight = weight_x * weight_y;
|
|
7466
|
+
|
|
7467
|
+
if (weight <= 0.0f) {
|
|
7468
|
+
continue;
|
|
7469
|
+
}
|
|
7470
|
+
|
|
7471
|
+
const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03);
|
|
7472
|
+
val += pixel * weight;
|
|
7473
|
+
total_weight += weight;
|
|
7474
|
+
}
|
|
7475
|
+
}
|
|
7312
7476
|
|
|
7477
|
+
if (total_weight > 0.0f) {
|
|
7478
|
+
val /= total_weight;
|
|
7479
|
+
}
|
|
7480
|
+
|
|
7481
|
+
float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
|
|
7482
|
+
*dst_ptr = val;
|
|
7483
|
+
}
|
|
7484
|
+
}
|
|
7485
|
+
}
|
|
7486
|
+
}
|
|
7487
|
+
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
|
|
7313
7488
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
7314
7489
|
const int64_t i03 = i3 / sf3;
|
|
7315
7490
|
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
|
|
@@ -7344,6 +7519,51 @@ static void ggml_compute_forward_upscale_f32(
|
|
|
7344
7519
|
|
|
7345
7520
|
const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
|
|
7346
7521
|
|
|
7522
|
+
float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
|
|
7523
|
+
*y_dst = val;
|
|
7524
|
+
}
|
|
7525
|
+
}
|
|
7526
|
+
}
|
|
7527
|
+
}
|
|
7528
|
+
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
|
|
7529
|
+
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
|
|
7530
|
+
const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
|
|
7531
|
+
auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
|
|
7532
|
+
auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
|
|
7533
|
+
auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
|
|
7534
|
+
const float w0 = weight2(x + 1);
|
|
7535
|
+
const float w1 = weight1(x + 0);
|
|
7536
|
+
const float w2 = weight1(1 - x);
|
|
7537
|
+
const float w3 = weight2(2 - x);
|
|
7538
|
+
return p0*w0 + p1*w1 + p2*w2 + p3*w3;
|
|
7539
|
+
};
|
|
7540
|
+
|
|
7541
|
+
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
7542
|
+
const int64_t i03 = i3 / sf3;
|
|
7543
|
+
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
|
|
7544
|
+
const int64_t i02 = i2 / sf2;
|
|
7545
|
+
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
7546
|
+
const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
|
|
7547
|
+
const int64_t y0 = (int64_t)floorf(y);
|
|
7548
|
+
const float dy = y - (float)y0;
|
|
7549
|
+
|
|
7550
|
+
for (int64_t i0 = 0; i0 < ne0; i0++) {
|
|
7551
|
+
const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
|
|
7552
|
+
const int64_t x0 = (int64_t)floorf(x);
|
|
7553
|
+
const float dx = x - (float)x0;
|
|
7554
|
+
|
|
7555
|
+
auto p = [=](int64_t x_off, int64_t y_off) -> float {
|
|
7556
|
+
int64_t i00 = std::max(int64_t(0), std::min(x0 + x_off, ne00 - 1));
|
|
7557
|
+
int64_t i01 = std::max(int64_t(0), std::min(y0 + y_off, ne01 - 1));
|
|
7558
|
+
return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
7559
|
+
};
|
|
7560
|
+
|
|
7561
|
+
const float val = bicubic(
|
|
7562
|
+
bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),
|
|
7563
|
+
bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),
|
|
7564
|
+
bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),
|
|
7565
|
+
bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);
|
|
7566
|
+
|
|
7347
7567
|
float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
|
|
7348
7568
|
*y_dst = val;
|
|
7349
7569
|
}
|
|
@@ -7376,6 +7596,7 @@ void ggml_compute_forward_upscale(
|
|
|
7376
7596
|
|
|
7377
7597
|
// ggml_compute_forward_pad
|
|
7378
7598
|
|
|
7599
|
+
template<bool circular_t>
|
|
7379
7600
|
static void ggml_compute_forward_pad_f32(
|
|
7380
7601
|
const ggml_compute_params * params,
|
|
7381
7602
|
ggml_tensor * dst) {
|
|
@@ -7391,6 +7612,14 @@ static void ggml_compute_forward_pad_f32(
|
|
|
7391
7612
|
GGML_TENSOR_UNARY_OP_LOCALS
|
|
7392
7613
|
|
|
7393
7614
|
float * dst_ptr = (float *) dst->data;
|
|
7615
|
+
const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
|
|
7616
|
+
const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
|
|
7617
|
+
const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
|
|
7618
|
+
const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
|
|
7619
|
+
const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
|
|
7620
|
+
const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
|
|
7621
|
+
const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
|
|
7622
|
+
const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
|
|
7394
7623
|
|
|
7395
7624
|
// TODO: optimize
|
|
7396
7625
|
|
|
@@ -7398,14 +7627,34 @@ static void ggml_compute_forward_pad_f32(
|
|
|
7398
7627
|
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
|
|
7399
7628
|
for (int64_t i0 = 0; i0 < ne0; ++i0) {
|
|
7400
7629
|
for (int64_t i3 = 0; i3 < ne3; ++i3) {
|
|
7401
|
-
|
|
7402
|
-
|
|
7403
|
-
|
|
7404
|
-
|
|
7405
|
-
|
|
7630
|
+
// circular means wrap around on a torus, so x and y loop around
|
|
7631
|
+
if constexpr (circular_t) {
|
|
7632
|
+
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
|
7633
|
+
const int64_t src_i0 = ggml_wrap_around(i0 - lp0, ne00);
|
|
7634
|
+
const int64_t src_i1 = ggml_wrap_around(i1 - lp1, ne01);
|
|
7635
|
+
const int64_t src_i2 = ggml_wrap_around(i2 - lp2, ne02);
|
|
7636
|
+
const int64_t src_i3 = ggml_wrap_around(i3 - lp3, ne03);
|
|
7637
|
+
|
|
7638
|
+
const int64_t src_idx =
|
|
7639
|
+
src_i3*nb03 +
|
|
7640
|
+
src_i2*nb02 +
|
|
7641
|
+
src_i1*nb01 +
|
|
7642
|
+
src_i0*nb00;
|
|
7643
|
+
|
|
7644
|
+
const float * src_ptr = (const float *)((char *) src0->data + src_idx);
|
|
7406
7645
|
dst_ptr[dst_idx] = *src_ptr;
|
|
7407
7646
|
} else {
|
|
7408
|
-
|
|
7647
|
+
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
|
7648
|
+
if ((i0 >= lp0 && i0 < ne0 - rp0) \
|
|
7649
|
+
&& (i1 >= lp1 && i1 < ne1 - rp1) \
|
|
7650
|
+
&& (i2 >= lp2 && i2 < ne2 - rp2) \
|
|
7651
|
+
&& (i3 >= lp3 && i3 < ne3 - rp3)) {
|
|
7652
|
+
const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
|
|
7653
|
+
const float * src_ptr = (const float *)((char *) src0->data + src_idx);
|
|
7654
|
+
dst_ptr[dst_idx] = *src_ptr;
|
|
7655
|
+
} else {
|
|
7656
|
+
dst_ptr[dst_idx] = 0;
|
|
7657
|
+
}
|
|
7409
7658
|
}
|
|
7410
7659
|
}
|
|
7411
7660
|
}
|
|
@@ -7413,16 +7662,20 @@ static void ggml_compute_forward_pad_f32(
|
|
|
7413
7662
|
}
|
|
7414
7663
|
}
|
|
7415
7664
|
|
|
7665
|
+
|
|
7416
7666
|
void ggml_compute_forward_pad(
|
|
7417
7667
|
const ggml_compute_params * params,
|
|
7418
7668
|
ggml_tensor * dst) {
|
|
7419
|
-
|
|
7420
7669
|
const ggml_tensor * src0 = dst->src[0];
|
|
7421
|
-
|
|
7670
|
+
const bool circular = (bool) ggml_get_op_params_i32(dst, 8);
|
|
7422
7671
|
switch (src0->type) {
|
|
7423
7672
|
case GGML_TYPE_F32:
|
|
7424
7673
|
{
|
|
7425
|
-
|
|
7674
|
+
if (circular) {
|
|
7675
|
+
ggml_compute_forward_pad_f32<true>(params, dst);
|
|
7676
|
+
} else {
|
|
7677
|
+
ggml_compute_forward_pad_f32<false>(params, dst);
|
|
7678
|
+
}
|
|
7426
7679
|
} break;
|
|
7427
7680
|
default:
|
|
7428
7681
|
{
|
|
@@ -7601,7 +7854,7 @@ static void ggml_compute_forward_timestep_embedding_f32(
|
|
|
7601
7854
|
embed_data[j + half] = sinf(arg);
|
|
7602
7855
|
}
|
|
7603
7856
|
if (dim % 2 != 0 && ith == 0) {
|
|
7604
|
-
embed_data[
|
|
7857
|
+
embed_data[2 * half] = 0.f;
|
|
7605
7858
|
}
|
|
7606
7859
|
}
|
|
7607
7860
|
}
|
|
@@ -7615,7 +7868,80 @@ void ggml_compute_forward_timestep_embedding(
|
|
|
7615
7868
|
switch (src0->type) {
|
|
7616
7869
|
case GGML_TYPE_F32:
|
|
7617
7870
|
{
|
|
7618
|
-
ggml_compute_forward_timestep_embedding_f32(params, dst);
|
|
7871
|
+
ggml_compute_forward_timestep_embedding_f32(params, dst);
|
|
7872
|
+
} break;
|
|
7873
|
+
default:
|
|
7874
|
+
{
|
|
7875
|
+
GGML_ABORT("fatal error");
|
|
7876
|
+
}
|
|
7877
|
+
}
|
|
7878
|
+
}
|
|
7879
|
+
|
|
7880
|
+
// ggml_compute_forward_argsort
|
|
7881
|
+
|
|
7882
|
+
template<enum ggml_sort_order order>
|
|
7883
|
+
struct cmp_argsort {
|
|
7884
|
+
const float * data;
|
|
7885
|
+
bool operator()(int32_t a, int32_t b) const {
|
|
7886
|
+
if constexpr (order == GGML_SORT_ORDER_ASC) {
|
|
7887
|
+
return data[a] < data[b];
|
|
7888
|
+
} else {
|
|
7889
|
+
return data[a] > data[b];
|
|
7890
|
+
}
|
|
7891
|
+
}
|
|
7892
|
+
};
|
|
7893
|
+
|
|
7894
|
+
static void ggml_compute_forward_argsort_f32(
|
|
7895
|
+
const ggml_compute_params * params,
|
|
7896
|
+
ggml_tensor * dst) {
|
|
7897
|
+
|
|
7898
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
7899
|
+
|
|
7900
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
|
7901
|
+
|
|
7902
|
+
GGML_ASSERT(nb0 == sizeof(float));
|
|
7903
|
+
|
|
7904
|
+
const int ith = params->ith;
|
|
7905
|
+
const int nth = params->nth;
|
|
7906
|
+
|
|
7907
|
+
const int64_t nr = ggml_nrows(src0);
|
|
7908
|
+
|
|
7909
|
+
ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
|
|
7910
|
+
|
|
7911
|
+
for (int64_t i = ith; i < nr; i += nth) {
|
|
7912
|
+
const float * src_data = (float *)((char *) src0->data + i*nb01);
|
|
7913
|
+
|
|
7914
|
+
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
|
|
7915
|
+
|
|
7916
|
+
for (int64_t j = 0; j < ne0; j++) {
|
|
7917
|
+
dst_data[j] = j;
|
|
7918
|
+
}
|
|
7919
|
+
|
|
7920
|
+
switch (order) {
|
|
7921
|
+
case GGML_SORT_ORDER_ASC:
|
|
7922
|
+
std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
|
|
7923
|
+
break;
|
|
7924
|
+
|
|
7925
|
+
case GGML_SORT_ORDER_DESC:
|
|
7926
|
+
std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
|
|
7927
|
+
break;
|
|
7928
|
+
|
|
7929
|
+
default:
|
|
7930
|
+
GGML_ABORT("invalid sort order");
|
|
7931
|
+
}
|
|
7932
|
+
}
|
|
7933
|
+
}
|
|
7934
|
+
|
|
7935
|
+
void ggml_compute_forward_argsort(
|
|
7936
|
+
const ggml_compute_params * params,
|
|
7937
|
+
ggml_tensor * dst) {
|
|
7938
|
+
|
|
7939
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
7940
|
+
|
|
7941
|
+
switch (src0->type) {
|
|
7942
|
+
case GGML_TYPE_F32:
|
|
7943
|
+
{
|
|
7944
|
+
ggml_compute_forward_argsort_f32(params, dst);
|
|
7619
7945
|
} break;
|
|
7620
7946
|
default:
|
|
7621
7947
|
{
|
|
@@ -7624,9 +7950,16 @@ void ggml_compute_forward_timestep_embedding(
|
|
|
7624
7950
|
}
|
|
7625
7951
|
}
|
|
7626
7952
|
|
|
7627
|
-
//
|
|
7953
|
+
// ggml_compute_forward_top_k
|
|
7628
7954
|
|
|
7629
|
-
|
|
7955
|
+
struct cmp_top_k {
|
|
7956
|
+
const float * data;
|
|
7957
|
+
bool operator()(int32_t a, int32_t b) const {
|
|
7958
|
+
return data[a] > data[b];
|
|
7959
|
+
}
|
|
7960
|
+
};
|
|
7961
|
+
|
|
7962
|
+
static void ggml_compute_forward_top_k_f32(
|
|
7630
7963
|
const ggml_compute_params * params,
|
|
7631
7964
|
ggml_tensor * dst) {
|
|
7632
7965
|
|
|
@@ -7641,31 +7974,31 @@ static void ggml_compute_forward_argsort_f32(
|
|
|
7641
7974
|
|
|
7642
7975
|
const int64_t nr = ggml_nrows(src0);
|
|
7643
7976
|
|
|
7644
|
-
|
|
7977
|
+
const int top_k = ne0;
|
|
7978
|
+
|
|
7979
|
+
int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
|
7645
7980
|
|
|
7646
7981
|
for (int64_t i = ith; i < nr; i += nth) {
|
|
7647
|
-
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
|
|
7648
7982
|
const float * src_data = (float *)((char *) src0->data + i*nb01);
|
|
7649
7983
|
|
|
7650
|
-
for (int64_t j = 0; j <
|
|
7651
|
-
|
|
7984
|
+
for (int64_t j = 0; j < ne00; j++) {
|
|
7985
|
+
tmp[j] = j;
|
|
7652
7986
|
}
|
|
7653
7987
|
|
|
7654
|
-
|
|
7655
|
-
|
|
7656
|
-
|
|
7657
|
-
|
|
7658
|
-
|
|
7659
|
-
|
|
7660
|
-
|
|
7661
|
-
|
|
7662
|
-
|
|
7663
|
-
}
|
|
7988
|
+
std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
|
|
7989
|
+
|
|
7990
|
+
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
|
|
7991
|
+
|
|
7992
|
+
std::copy(tmp, tmp + top_k, dst_data);
|
|
7993
|
+
|
|
7994
|
+
// emphasize that the order is not important
|
|
7995
|
+
if (top_k > 1) {
|
|
7996
|
+
std::swap(dst_data[0], dst_data[1]);
|
|
7664
7997
|
}
|
|
7665
7998
|
}
|
|
7666
7999
|
}
|
|
7667
8000
|
|
|
7668
|
-
void
|
|
8001
|
+
void ggml_compute_forward_top_k(
|
|
7669
8002
|
const ggml_compute_params * params,
|
|
7670
8003
|
ggml_tensor * dst) {
|
|
7671
8004
|
|
|
@@ -7674,7 +8007,7 @@ void ggml_compute_forward_argsort(
|
|
|
7674
8007
|
switch (src0->type) {
|
|
7675
8008
|
case GGML_TYPE_F32:
|
|
7676
8009
|
{
|
|
7677
|
-
|
|
8010
|
+
ggml_compute_forward_top_k_f32(params, dst);
|
|
7678
8011
|
} break;
|
|
7679
8012
|
default:
|
|
7680
8013
|
{
|
|
@@ -7685,13 +8018,15 @@ void ggml_compute_forward_argsort(
|
|
|
7685
8018
|
|
|
7686
8019
|
// ggml_compute_forward_flash_attn_ext
|
|
7687
8020
|
|
|
7688
|
-
static void
|
|
8021
|
+
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|
7689
8022
|
const ggml_compute_params * params,
|
|
7690
|
-
|
|
7691
|
-
|
|
7692
|
-
|
|
7693
|
-
|
|
7694
|
-
|
|
8023
|
+
ggml_tensor * dst,
|
|
8024
|
+
int ir0, int ir1) {
|
|
8025
|
+
const ggml_tensor * q = dst->src[0];
|
|
8026
|
+
const ggml_tensor * k = dst->src[1];
|
|
8027
|
+
const ggml_tensor * v = dst->src[2];
|
|
8028
|
+
const ggml_tensor * mask = dst->src[3];
|
|
8029
|
+
const ggml_tensor * sinks = dst->src[4];
|
|
7695
8030
|
|
|
7696
8031
|
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
|
7697
8032
|
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
|
@@ -7702,9 +8037,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
7702
8037
|
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
7703
8038
|
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
7704
8039
|
|
|
7705
|
-
const int ith = params->ith;
|
|
7706
|
-
const int nth = params->nth;
|
|
7707
|
-
|
|
7708
8040
|
const int64_t DK = nek0;
|
|
7709
8041
|
const int64_t DV = nev0;
|
|
7710
8042
|
const int64_t N = neq1;
|
|
@@ -7738,16 +8070,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
7738
8070
|
|
|
7739
8071
|
// parallelize by q rows using ggml_vec_dot_f32
|
|
7740
8072
|
|
|
7741
|
-
// total rows in q
|
|
7742
|
-
const int nr = neq1*neq2*neq3;
|
|
7743
|
-
|
|
7744
|
-
// rows per thread
|
|
7745
|
-
const int dr = (nr + nth - 1)/nth;
|
|
7746
|
-
|
|
7747
|
-
// row range for this thread
|
|
7748
|
-
const int ir0 = dr*ith;
|
|
7749
|
-
const int ir1 = MIN(ir0 + dr, nr);
|
|
7750
|
-
|
|
7751
8073
|
float scale = 1.0f;
|
|
7752
8074
|
float max_bias = 0.0f;
|
|
7753
8075
|
float logit_softcap = 0.0f;
|
|
@@ -7766,7 +8088,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
7766
8088
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
7767
8089
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
7768
8090
|
|
|
7769
|
-
ggml_type
|
|
8091
|
+
ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
|
|
7770
8092
|
ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
|
|
7771
8093
|
ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
|
|
7772
8094
|
ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
|
|
@@ -7774,6 +8096,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
7774
8096
|
GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
|
|
7775
8097
|
GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
|
|
7776
8098
|
|
|
8099
|
+
int ith = params->ith;
|
|
8100
|
+
|
|
7777
8101
|
// loop over n_batch and n_head
|
|
7778
8102
|
for (int ir = ir0; ir < ir1; ++ir) {
|
|
7779
8103
|
// q indices
|
|
@@ -7798,7 +8122,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
7798
8122
|
memset(VKQ32, 0, DV*sizeof(float));
|
|
7799
8123
|
}
|
|
7800
8124
|
|
|
7801
|
-
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
|
|
8125
|
+
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
|
|
7802
8126
|
|
|
7803
8127
|
// k indices
|
|
7804
8128
|
const int ik3 = iq3 / rk3;
|
|
@@ -7887,8 +8211,25 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
7887
8211
|
}
|
|
7888
8212
|
}
|
|
7889
8213
|
|
|
8214
|
+
// sinks
|
|
8215
|
+
if (sinks) {
|
|
8216
|
+
const float s = ((float *)((char *) sinks->data))[h];
|
|
8217
|
+
|
|
8218
|
+
float ms = 1.0f;
|
|
8219
|
+
float vs = 1.0f;
|
|
8220
|
+
|
|
8221
|
+
if (s > M) {
|
|
8222
|
+
ms = expf(M - s);
|
|
8223
|
+
ggml_vec_scale_f32(DV, VKQ32, ms);
|
|
8224
|
+
} else {
|
|
8225
|
+
vs = expf(s - M);
|
|
8226
|
+
}
|
|
8227
|
+
|
|
8228
|
+
S = S*ms + vs;
|
|
8229
|
+
}
|
|
8230
|
+
|
|
7890
8231
|
// V /= S
|
|
7891
|
-
const float S_inv = 1.0f/S;
|
|
8232
|
+
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
|
|
7892
8233
|
ggml_vec_scale_f32(DV, VKQ32, S_inv);
|
|
7893
8234
|
|
|
7894
8235
|
// dst indices
|
|
@@ -7904,19 +8245,100 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
7904
8245
|
}
|
|
7905
8246
|
}
|
|
7906
8247
|
|
|
8248
|
+
static void ggml_compute_forward_flash_attn_ext_f16(
|
|
8249
|
+
const ggml_compute_params * params,
|
|
8250
|
+
ggml_tensor * dst) {
|
|
8251
|
+
|
|
8252
|
+
const ggml_tensor * q = dst->src[0];
|
|
8253
|
+
const ggml_tensor * k = dst->src[1];
|
|
8254
|
+
const ggml_tensor * v = dst->src[2];
|
|
8255
|
+
|
|
8256
|
+
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
|
8257
|
+
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
|
8258
|
+
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
|
8259
|
+
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
|
8260
|
+
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
|
8261
|
+
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
|
8262
|
+
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
8263
|
+
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
8264
|
+
|
|
8265
|
+
const int64_t DK = nek0;
|
|
8266
|
+
const int64_t DV = nev0;
|
|
8267
|
+
const int64_t N = neq1;
|
|
8268
|
+
|
|
8269
|
+
GGML_ASSERT(ne0 == DV);
|
|
8270
|
+
GGML_ASSERT(ne2 == N);
|
|
8271
|
+
|
|
8272
|
+
// input tensor rows must be contiguous
|
|
8273
|
+
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
|
|
8274
|
+
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
|
|
8275
|
+
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
|
|
8276
|
+
|
|
8277
|
+
GGML_ASSERT(neq0 == DK);
|
|
8278
|
+
GGML_ASSERT(nek0 == DK);
|
|
8279
|
+
GGML_ASSERT(nev0 == DV);
|
|
8280
|
+
|
|
8281
|
+
GGML_ASSERT(neq1 == N);
|
|
8282
|
+
|
|
8283
|
+
// dst cannot be transposed or permuted
|
|
8284
|
+
GGML_ASSERT(nb0 == sizeof(float));
|
|
8285
|
+
GGML_ASSERT(nb0 <= nb1);
|
|
8286
|
+
GGML_ASSERT(nb1 <= nb2);
|
|
8287
|
+
GGML_ASSERT(nb2 <= nb3);
|
|
8288
|
+
|
|
8289
|
+
// parallelize by q rows using ggml_vec_dot_f32
|
|
8290
|
+
|
|
8291
|
+
// total rows in q
|
|
8292
|
+
const int64_t nr = neq1*neq2*neq3;
|
|
8293
|
+
|
|
8294
|
+
// rows per thread
|
|
8295
|
+
const int ith = params->ith;
|
|
8296
|
+
const int nth = params->nth;
|
|
8297
|
+
|
|
8298
|
+
// disable for NUMA
|
|
8299
|
+
const bool disable_chunking = ggml_is_numa();
|
|
8300
|
+
|
|
8301
|
+
// 4x chunks per thread
|
|
8302
|
+
int nth_scaled = nth * 4;
|
|
8303
|
+
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
|
|
8304
|
+
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
|
|
8305
|
+
|
|
8306
|
+
if (nth == 1 || nchunk < nth || disable_chunking) {
|
|
8307
|
+
nchunk = nth;
|
|
8308
|
+
}
|
|
8309
|
+
|
|
8310
|
+
if (ith == 0) {
|
|
8311
|
+
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
|
|
8312
|
+
ggml_threadpool_chunk_set(params->threadpool, nth);
|
|
8313
|
+
}
|
|
8314
|
+
|
|
8315
|
+
ggml_barrier(params->threadpool);
|
|
8316
|
+
|
|
8317
|
+
// The number of elements in each chunk
|
|
8318
|
+
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
|
8319
|
+
|
|
8320
|
+
// The first chunk comes from our thread_id, the rest will get auto-assigned.
|
|
8321
|
+
int current_chunk = ith;
|
|
8322
|
+
|
|
8323
|
+
while (current_chunk < nchunk) {
|
|
8324
|
+
const int64_t ir0 = dr * current_chunk;
|
|
8325
|
+
const int64_t ir1 = MIN(ir0 + dr, nr);
|
|
8326
|
+
|
|
8327
|
+
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
|
|
8328
|
+
|
|
8329
|
+
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
|
8330
|
+
}
|
|
8331
|
+
}
|
|
8332
|
+
|
|
7907
8333
|
void ggml_compute_forward_flash_attn_ext(
|
|
7908
8334
|
const ggml_compute_params * params,
|
|
7909
|
-
const ggml_tensor * q,
|
|
7910
|
-
const ggml_tensor * k,
|
|
7911
|
-
const ggml_tensor * v,
|
|
7912
|
-
const ggml_tensor * mask,
|
|
7913
8335
|
ggml_tensor * dst) {
|
|
7914
8336
|
switch (dst->op_params[3]) {
|
|
7915
8337
|
case GGML_PREC_DEFAULT:
|
|
7916
8338
|
case GGML_PREC_F32:
|
|
7917
8339
|
{
|
|
7918
8340
|
// uses F32 accumulators
|
|
7919
|
-
ggml_compute_forward_flash_attn_ext_f16(params,
|
|
8341
|
+
ggml_compute_forward_flash_attn_ext_f16(params, dst);
|
|
7920
8342
|
} break;
|
|
7921
8343
|
default:
|
|
7922
8344
|
{
|
|
@@ -8336,120 +8758,214 @@ void ggml_compute_forward_ssm_conv(
|
|
|
8336
8758
|
static void ggml_compute_forward_ssm_scan_f32(
|
|
8337
8759
|
const ggml_compute_params * params,
|
|
8338
8760
|
ggml_tensor * dst) {
|
|
8339
|
-
const ggml_tensor * src0 = dst->src[0]; // s
|
|
8340
|
-
const ggml_tensor * src1 = dst->src[1]; // x
|
|
8341
|
-
const ggml_tensor * src2 = dst->src[2]; // dt
|
|
8342
|
-
const ggml_tensor * src3 = dst->src[3]; // A
|
|
8343
|
-
const ggml_tensor * src4 = dst->src[4]; // B
|
|
8344
|
-
const ggml_tensor * src5 = dst->src[5]; // C
|
|
8761
|
+
const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
|
|
8762
|
+
const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
|
|
8763
|
+
const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
|
|
8764
|
+
const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
|
|
8765
|
+
const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
|
|
8766
|
+
const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
|
|
8767
|
+
const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
|
|
8345
8768
|
|
|
8346
8769
|
const int ith = params->ith;
|
|
8347
8770
|
const int nth = params->nth;
|
|
8348
8771
|
|
|
8349
|
-
const int64_t nc
|
|
8350
|
-
const int64_t nr
|
|
8351
|
-
const int64_t
|
|
8352
|
-
const int64_t
|
|
8772
|
+
const int64_t nc = src0->ne[0]; // d_state
|
|
8773
|
+
const int64_t nr = src0->ne[1]; // dim
|
|
8774
|
+
const int64_t nh = src1->ne[1]; // n_head
|
|
8775
|
+
const int64_t ng = src4->ne[1];
|
|
8776
|
+
const int64_t nt = src1->ne[2]; // number of tokens per sequence
|
|
8777
|
+
const int64_t ns = src1->ne[3]; // number of sequences in the batch
|
|
8778
|
+
|
|
8779
|
+
// can't use ggml_nbytes because src1 is not necessarily contiguous
|
|
8780
|
+
const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
|
|
8353
8781
|
|
|
8354
|
-
GGML_ASSERT(ggml_nelements(src1) +
|
|
8782
|
+
GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
|
|
8355
8783
|
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
|
8356
8784
|
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
|
8357
8785
|
GGML_ASSERT(src2->nb[0] == sizeof(float));
|
|
8358
8786
|
GGML_ASSERT(src3->nb[0] == sizeof(float));
|
|
8359
8787
|
GGML_ASSERT(src4->nb[0] == sizeof(float));
|
|
8360
8788
|
GGML_ASSERT(src5->nb[0] == sizeof(float));
|
|
8361
|
-
|
|
8362
|
-
GGML_ASSERT(
|
|
8363
|
-
// required for per-sequence offsets for states
|
|
8364
|
-
GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
|
|
8365
|
-
// required to get correct offset for state destination (i.e. src1->nb[3])
|
|
8366
|
-
GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
|
|
8789
|
+
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
|
|
8790
|
+
GGML_ASSERT(nh % ng == 0);
|
|
8367
8791
|
|
|
8368
|
-
//
|
|
8369
|
-
const int
|
|
8792
|
+
// heads per thread
|
|
8793
|
+
const int dh = (nh + nth - 1)/nth;
|
|
8370
8794
|
|
|
8371
|
-
//
|
|
8372
|
-
const int
|
|
8373
|
-
const int
|
|
8374
|
-
|
|
8795
|
+
// head range for this thread
|
|
8796
|
+
const int ih0 = dh*ith;
|
|
8797
|
+
const int ih1 = MIN(ih0 + dh, nh);
|
|
8798
|
+
|
|
8799
|
+
const int32_t * ids = (const int32_t *) src6->data;
|
|
8800
|
+
|
|
8801
|
+
for (int i3 = 0; i3 < ns; ++i3) {
|
|
8802
|
+
const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
|
|
8803
|
+
float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
|
|
8804
|
+
|
|
8805
|
+
for (int i2 = 0; i2 < nt; ++i2) {
|
|
8806
|
+
const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
|
|
8807
|
+
const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
|
|
8808
|
+
const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
|
|
8809
|
+
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
|
|
8810
|
+
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
|
|
8811
|
+
float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
|
|
8812
|
+
|
|
8813
|
+
if (src3->ne[0] == 1) {
|
|
8814
|
+
// Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
|
|
8815
|
+
|
|
8816
|
+
// n_head
|
|
8817
|
+
for (int h = ih0; h < ih1; ++h) {
|
|
8818
|
+
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
|
8819
|
+
const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
|
|
8820
|
+
const float dA = expf(dt_soft_plus * A[h]);
|
|
8821
|
+
const int g = h / (nh / ng); // repeat_interleave
|
|
8822
|
+
|
|
8823
|
+
// dim
|
|
8824
|
+
for (int i1 = 0; i1 < nr; ++i1) {
|
|
8825
|
+
const int ii = i1 + h*nr;
|
|
8826
|
+
const float x_dt = x[ii] * dt_soft_plus;
|
|
8827
|
+
float sumf = 0.0f;
|
|
8828
|
+
#if defined(GGML_SIMD)
|
|
8829
|
+
#if defined(__ARM_FEATURE_SVE)
|
|
8830
|
+
const int ggml_f32_epr = svcntw();
|
|
8831
|
+
const int ggml_f32_step = 1 * ggml_f32_epr;
|
|
8832
|
+
|
|
8833
|
+
const int np = (nc & ~(ggml_f32_step - 1));
|
|
8834
|
+
|
|
8835
|
+
GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
|
|
8836
|
+
|
|
8837
|
+
GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
|
|
8838
|
+
GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
|
|
8839
|
+
|
|
8840
|
+
for (int i = 0; i < np; i += ggml_f32_step) {
|
|
8841
|
+
// TODO: maybe unroll more?
|
|
8842
|
+
for (int j = 0; j < 1; j++) {
|
|
8843
|
+
GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
|
|
8844
|
+
GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
|
|
8845
|
+
GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
|
|
8846
|
+
|
|
8847
|
+
t0 = GGML_F32_VEC_MUL(t0, adA);
|
|
8848
|
+
t1 = GGML_F32_VEC_MUL(t1, axdt);
|
|
8849
|
+
|
|
8850
|
+
t0 = GGML_F32_VEC_ADD(t0, t1);
|
|
8851
|
+
|
|
8852
|
+
sum = GGML_F32_VEC_FMA(sum, t0, t2);
|
|
8375
8853
|
|
|
8376
|
-
|
|
8377
|
-
|
|
8378
|
-
|
|
8379
|
-
|
|
8380
|
-
|
|
8381
|
-
|
|
8382
|
-
|
|
8383
|
-
|
|
8384
|
-
|
|
8385
|
-
|
|
8386
|
-
|
|
8387
|
-
|
|
8388
|
-
|
|
8389
|
-
|
|
8390
|
-
|
|
8391
|
-
|
|
8392
|
-
|
|
8393
|
-
|
|
8394
|
-
|
|
8395
|
-
|
|
8396
|
-
|
|
8397
|
-
|
|
8398
|
-
|
|
8399
|
-
|
|
8400
|
-
|
|
8401
|
-
|
|
8402
|
-
|
|
8403
|
-
|
|
8404
|
-
|
|
8405
|
-
|
|
8406
|
-
|
|
8407
|
-
|
|
8408
|
-
|
|
8409
|
-
|
|
8410
|
-
|
|
8411
|
-
|
|
8412
|
-
|
|
8854
|
+
GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
|
|
8855
|
+
}
|
|
8856
|
+
}
|
|
8857
|
+
|
|
8858
|
+
sumf = GGML_F32xt_REDUCE_ONE(sum);
|
|
8859
|
+
#elif defined(__riscv_v_intrinsic)
|
|
8860
|
+
// todo: RVV implementation
|
|
8861
|
+
const int np = 0;
|
|
8862
|
+
#else
|
|
8863
|
+
const int np = (nc & ~(GGML_F32_STEP - 1));
|
|
8864
|
+
|
|
8865
|
+
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
|
|
8866
|
+
|
|
8867
|
+
GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
|
|
8868
|
+
GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
|
|
8869
|
+
|
|
8870
|
+
GGML_F32_VEC ax[GGML_F32_ARR];
|
|
8871
|
+
GGML_F32_VEC ay[GGML_F32_ARR];
|
|
8872
|
+
GGML_F32_VEC az[GGML_F32_ARR];
|
|
8873
|
+
|
|
8874
|
+
for (int i = 0; i < np; i += GGML_F32_STEP) {
|
|
8875
|
+
for (int j = 0; j < GGML_F32_ARR; j++) {
|
|
8876
|
+
ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
|
|
8877
|
+
ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc);
|
|
8878
|
+
az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc);
|
|
8879
|
+
|
|
8880
|
+
ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
|
|
8881
|
+
ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
|
|
8882
|
+
|
|
8883
|
+
ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
|
|
8884
|
+
|
|
8885
|
+
sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
|
|
8886
|
+
|
|
8887
|
+
GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
|
|
8888
|
+
}
|
|
8889
|
+
}
|
|
8890
|
+
|
|
8891
|
+
// reduce sum0..sum3 to sum0
|
|
8892
|
+
GGML_F32_VEC_REDUCE(sumf, sum);
|
|
8893
|
+
#endif
|
|
8894
|
+
#else
|
|
8895
|
+
const int np = 0;
|
|
8896
|
+
#endif
|
|
8897
|
+
// d_state
|
|
8898
|
+
for (int i0 = np; i0 < nc; ++i0) {
|
|
8899
|
+
const int i = i0 + ii*nc;
|
|
8900
|
+
const int ig = i0 + g*nc;
|
|
8901
|
+
// state = prev_state * dA + dB * x
|
|
8902
|
+
const float state = (s0[i] * dA) + (B[ig] * x_dt);
|
|
8903
|
+
// y = rowwise_dotprod(state, C)
|
|
8904
|
+
sumf += state * C[ig];
|
|
8905
|
+
s[i] = state;
|
|
8906
|
+
}
|
|
8907
|
+
y[ii] = sumf;
|
|
8413
8908
|
}
|
|
8414
|
-
y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
|
|
8415
8909
|
}
|
|
8416
|
-
}
|
|
8417
|
-
|
|
8418
|
-
|
|
8419
|
-
|
|
8420
|
-
|
|
8421
|
-
|
|
8422
|
-
|
|
8423
|
-
|
|
8424
|
-
|
|
8425
|
-
|
|
8426
|
-
|
|
8427
|
-
|
|
8428
|
-
|
|
8429
|
-
|
|
8430
|
-
|
|
8431
|
-
|
|
8432
|
-
|
|
8433
|
-
|
|
8434
|
-
|
|
8435
|
-
|
|
8436
|
-
|
|
8437
|
-
|
|
8438
|
-
|
|
8439
|
-
|
|
8440
|
-
|
|
8441
|
-
|
|
8442
|
-
|
|
8443
|
-
|
|
8444
|
-
|
|
8445
|
-
|
|
8446
|
-
|
|
8910
|
+
} else {
|
|
8911
|
+
// Mamba-1 has an element-wise decay factor for the states
|
|
8912
|
+
|
|
8913
|
+
// n_head
|
|
8914
|
+
for (int h = ih0; h < ih1; ++h) {
|
|
8915
|
+
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
|
8916
|
+
const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
|
|
8917
|
+
const int g = h / (nh / ng); // repeat_interleave
|
|
8918
|
+
|
|
8919
|
+
// dim
|
|
8920
|
+
for (int i1 = 0; i1 < nr; ++i1) {
|
|
8921
|
+
const int ii = i1 + h*nr;
|
|
8922
|
+
const float x_dt = x[ii] * dt_soft_plus;
|
|
8923
|
+
#if defined(__ARM_FEATURE_SVE)
|
|
8924
|
+
svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
|
|
8925
|
+
svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
|
|
8926
|
+
svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
|
|
8927
|
+
|
|
8928
|
+
// d_state
|
|
8929
|
+
// TODO: what happens when (d_state % svcntw()) != 0?
|
|
8930
|
+
for (int64_t k = 0; k < nc; k += svcntw()) {
|
|
8931
|
+
svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
|
|
8932
|
+
svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
|
|
8933
|
+
svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
|
|
8934
|
+
svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
|
|
8935
|
+
|
|
8936
|
+
svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
|
|
8937
|
+
t1 = exp_ps_sve(svptrue_b32(), t1);
|
|
8938
|
+
svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
|
|
8939
|
+
|
|
8940
|
+
vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
|
|
8941
|
+
r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
|
|
8942
|
+
|
|
8943
|
+
GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
|
|
8944
|
+
}
|
|
8945
|
+
y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
|
|
8946
|
+
#else
|
|
8947
|
+
float sumf = 0.0f;
|
|
8948
|
+
// NOTE: can't really use GGML_SIMD here because d_state is usually 16
|
|
8949
|
+
// and also because expf is used within the loop.
|
|
8950
|
+
// d_state
|
|
8951
|
+
for (int i0 = 0; i0 < nc; ++i0) {
|
|
8952
|
+
const int i = i0 + ii*nc;
|
|
8953
|
+
const int ig = i0 + g*nc;
|
|
8954
|
+
// state = prev_state * dA + dB * x
|
|
8955
|
+
const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
|
|
8956
|
+
// y = rowwise_dotprod(state, C)
|
|
8957
|
+
sumf += state * C[ig];
|
|
8958
|
+
s[i] = state;
|
|
8959
|
+
}
|
|
8960
|
+
y[ii] = sumf;
|
|
8961
|
+
#endif
|
|
8447
8962
|
}
|
|
8448
|
-
y[i1] = sumf;
|
|
8449
8963
|
}
|
|
8450
8964
|
}
|
|
8965
|
+
// use the output as the source when it's not the first token-wise iteration
|
|
8966
|
+
s0 = s;
|
|
8451
8967
|
}
|
|
8452
|
-
|
|
8968
|
+
}
|
|
8453
8969
|
}
|
|
8454
8970
|
|
|
8455
8971
|
void ggml_compute_forward_ssm_scan(
|
|
@@ -8660,6 +9176,34 @@ void ggml_compute_forward_unary(
|
|
|
8660
9176
|
{
|
|
8661
9177
|
ggml_compute_forward_exp(params, dst);
|
|
8662
9178
|
} break;
|
|
9179
|
+
case GGML_UNARY_OP_FLOOR:
|
|
9180
|
+
{
|
|
9181
|
+
ggml_compute_forward_floor(params, dst);
|
|
9182
|
+
} break;
|
|
9183
|
+
case GGML_UNARY_OP_CEIL:
|
|
9184
|
+
{
|
|
9185
|
+
ggml_compute_forward_ceil(params, dst);
|
|
9186
|
+
} break;
|
|
9187
|
+
case GGML_UNARY_OP_ROUND:
|
|
9188
|
+
{
|
|
9189
|
+
ggml_compute_forward_round(params, dst);
|
|
9190
|
+
} break;
|
|
9191
|
+
case GGML_UNARY_OP_TRUNC:
|
|
9192
|
+
{
|
|
9193
|
+
ggml_compute_forward_trunc(params, dst);
|
|
9194
|
+
} break;
|
|
9195
|
+
case GGML_UNARY_OP_XIELU:
|
|
9196
|
+
{
|
|
9197
|
+
ggml_compute_forward_xielu(params, dst);
|
|
9198
|
+
} break;
|
|
9199
|
+
case GGML_UNARY_OP_EXPM1:
|
|
9200
|
+
{
|
|
9201
|
+
ggml_compute_forward_expm1(params, dst);
|
|
9202
|
+
} break;
|
|
9203
|
+
case GGML_UNARY_OP_SOFTPLUS:
|
|
9204
|
+
{
|
|
9205
|
+
ggml_compute_forward_softplus(params, dst);
|
|
9206
|
+
} break;
|
|
8663
9207
|
default:
|
|
8664
9208
|
{
|
|
8665
9209
|
GGML_ABORT("fatal error");
|
|
@@ -8688,6 +9232,18 @@ void ggml_compute_forward_glu(
|
|
|
8688
9232
|
{
|
|
8689
9233
|
ggml_compute_forward_swiglu(params, dst);
|
|
8690
9234
|
} break;
|
|
9235
|
+
case GGML_GLU_OP_SWIGLU_OAI:
|
|
9236
|
+
{
|
|
9237
|
+
ggml_compute_forward_swiglu_oai(params, dst);
|
|
9238
|
+
} break;
|
|
9239
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
9240
|
+
{
|
|
9241
|
+
ggml_compute_forward_geglu_erf(params, dst);
|
|
9242
|
+
} break;
|
|
9243
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
9244
|
+
{
|
|
9245
|
+
ggml_compute_forward_geglu_quick(params, dst);
|
|
9246
|
+
} break;
|
|
8691
9247
|
default:
|
|
8692
9248
|
{
|
|
8693
9249
|
GGML_ABORT("fatal error");
|
|
@@ -9244,6 +9800,76 @@ void ggml_compute_forward_gla(
|
|
|
9244
9800
|
}
|
|
9245
9801
|
}
|
|
9246
9802
|
|
|
9803
|
+
static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
|
9804
|
+
const struct ggml_tensor * src0 = dst->src[0]; // A (lower triangular)
|
|
9805
|
+
const struct ggml_tensor * src1 = dst->src[1]; // B (RHS)
|
|
9806
|
+
|
|
9807
|
+
GGML_TENSOR_BINARY_OP_LOCALS;
|
|
9808
|
+
|
|
9809
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
9810
|
+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
9811
|
+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
9812
|
+
|
|
9813
|
+
GGML_ASSERT(ne00 == ne01); // A must be square
|
|
9814
|
+
GGML_ASSERT(ne0 == ne10); // solution cols == B cols
|
|
9815
|
+
GGML_ASSERT(ne1 == ne11); // solution rows == B rows
|
|
9816
|
+
|
|
9817
|
+
GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
|
|
9818
|
+
GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
|
|
9819
|
+
|
|
9820
|
+
const int ith = params->ith;
|
|
9821
|
+
const int nth = params->nth;
|
|
9822
|
+
|
|
9823
|
+
const int64_t k = ne10; // number of RHS columns
|
|
9824
|
+
const int64_t n = ne11; // A is n×n
|
|
9825
|
+
const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit
|
|
9826
|
+
|
|
9827
|
+
// chunks per thread
|
|
9828
|
+
const int64_t dr = (nr + nth - 1)/nth;
|
|
9829
|
+
|
|
9830
|
+
// chunk range for this thread
|
|
9831
|
+
const int64_t ir0 = dr*ith;
|
|
9832
|
+
const int64_t ir1 = MIN(ir0 + dr, nr);
|
|
9833
|
+
|
|
9834
|
+
const float * A = (const float *) src0->data; // [n, n, B1, B2]
|
|
9835
|
+
const float * B = (const float *) src1->data; // [n, k, B1, B2]
|
|
9836
|
+
float * X = ( float *) dst->data; // [n, k, B1, B2]
|
|
9837
|
+
|
|
9838
|
+
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
|
9839
|
+
const int64_t i03 = ir/(ne02*k);
|
|
9840
|
+
const int64_t i02 = (ir - i03*ne02*k)/k;
|
|
9841
|
+
const int64_t i01 = (ir - i03*ne02*k - i02*k);
|
|
9842
|
+
|
|
9843
|
+
const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
|
|
9844
|
+
const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
|
|
9845
|
+
|
|
9846
|
+
float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
|
|
9847
|
+
|
|
9848
|
+
for (int64_t i00 = 0; i00 < n; ++i00) {
|
|
9849
|
+
float sum = 0.0f;
|
|
9850
|
+
for (int64_t t = 0; t < i00; ++t) {
|
|
9851
|
+
sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
|
|
9852
|
+
}
|
|
9853
|
+
|
|
9854
|
+
const float diag = A_batch[i00 * n + i00];
|
|
9855
|
+
assert(diag != 0.0f && "Zero diagonal in triangular matrix");
|
|
9856
|
+
|
|
9857
|
+
X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
|
|
9858
|
+
}
|
|
9859
|
+
}
|
|
9860
|
+
}
|
|
9861
|
+
|
|
9862
|
+
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
|
9863
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
9864
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
9865
|
+
|
|
9866
|
+
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
|
9867
|
+
ggml_compute_forward_solve_tri_f32(params, dst);
|
|
9868
|
+
} else {
|
|
9869
|
+
GGML_ABORT("fatal error");
|
|
9870
|
+
}
|
|
9871
|
+
}
|
|
9872
|
+
|
|
9247
9873
|
// ggml_compute_forward_rwkv_wkv7
|
|
9248
9874
|
|
|
9249
9875
|
static void ggml_compute_forward_rwkv_wkv7_f32(
|
|
@@ -9283,8 +9909,8 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
|
|
|
9283
9909
|
int64_t h_stride_2d = head_size * head_size;
|
|
9284
9910
|
|
|
9285
9911
|
#if defined(GGML_SIMD)
|
|
9286
|
-
#if defined(__ARM_FEATURE_SVE)
|
|
9287
|
-
// scalar Route to scalar implementation //TODO: Write SVE code
|
|
9912
|
+
#if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
|
|
9913
|
+
// scalar Route to scalar implementation //TODO: Write SVE code and RVV code
|
|
9288
9914
|
for (int64_t t = 0; t < T; t++) {
|
|
9289
9915
|
int64_t t_offset = t * t_stride;
|
|
9290
9916
|
int64_t state_offset = head_size * C * (t / (T / n_seqs));
|
|
@@ -9732,6 +10358,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
|
|
|
9732
10358
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
9733
10359
|
|
|
9734
10360
|
const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
|
|
10361
|
+
|
|
9735
10362
|
const float alpha = adamw_params_ptr[0];
|
|
9736
10363
|
const float beta1 = adamw_params_ptr[1];
|
|
9737
10364
|
const float beta2 = adamw_params_ptr[2];
|
|
@@ -9739,7 +10366,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
|
|
|
9739
10366
|
const float wd = adamw_params_ptr[4];
|
|
9740
10367
|
const float beta1h = adamw_params_ptr[5];
|
|
9741
10368
|
const float beta2h = adamw_params_ptr[6];
|
|
9742
|
-
|
|
10369
|
+
const float keep = 1.f - alpha * wd;
|
|
9743
10370
|
for (int ir = ir0; ir < ir1; ++ir) {
|
|
9744
10371
|
const int64_t i03 = ir/(ne02*ne01);
|
|
9745
10372
|
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
|
@@ -9762,7 +10389,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
|
|
|
9762
10389
|
// The weight decay is applied independently of the Adam momenta m and v.
|
|
9763
10390
|
// This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
|
|
9764
10391
|
// See: https://arxiv.org/pdf/1711.05101v3.pdf
|
|
9765
|
-
w[i00] = w[i00]*
|
|
10392
|
+
w[i00] = w[i00] * keep - alpha * mh / vh;
|
|
9766
10393
|
}
|
|
9767
10394
|
}
|
|
9768
10395
|
}
|
|
@@ -9784,3 +10411,63 @@ void ggml_compute_forward_opt_step_adamw(
|
|
|
9784
10411
|
}
|
|
9785
10412
|
}
|
|
9786
10413
|
}
|
|
10414
|
+
|
|
10415
|
+
static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
10416
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
10417
|
+
const ggml_tensor * src0_grad = dst->src[1];
|
|
10418
|
+
const ggml_tensor * sgd_params = dst->src[2];
|
|
10419
|
+
|
|
10420
|
+
GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
|
|
10421
|
+
GGML_ASSERT(ggml_nelements(sgd_params) == 2);
|
|
10422
|
+
|
|
10423
|
+
const int ith = params->ith;
|
|
10424
|
+
const int nth = params->nth;
|
|
10425
|
+
|
|
10426
|
+
const int nr = ggml_nrows(src0);
|
|
10427
|
+
|
|
10428
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
|
10429
|
+
GGML_ASSERT(nb00 == sizeof(float));
|
|
10430
|
+
|
|
10431
|
+
// rows per thread
|
|
10432
|
+
const int dr = (nr + nth - 1) / nth;
|
|
10433
|
+
|
|
10434
|
+
// row range for this thread
|
|
10435
|
+
const int ir0 = dr * ith;
|
|
10436
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
10437
|
+
|
|
10438
|
+
// using adamw param subset we care about - alpha, wd - could have a separate struct
|
|
10439
|
+
const float * sgd_params_ptr = ggml_get_data_f32(sgd_params);
|
|
10440
|
+
const float alpha = sgd_params_ptr[0];
|
|
10441
|
+
const float keep = 1.f - alpha * sgd_params_ptr[1];
|
|
10442
|
+
|
|
10443
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
|
10444
|
+
const int64_t i03 = ir / (ne02 * ne01);
|
|
10445
|
+
const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
|
|
10446
|
+
const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
|
|
10447
|
+
|
|
10448
|
+
const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
|
|
10449
|
+
|
|
10450
|
+
float * w = (float *) ((char *) src0->data + offset); // weight
|
|
10451
|
+
const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
|
|
10452
|
+
|
|
10453
|
+
for (int i00 = 0; i00 < ne00; ++i00) {
|
|
10454
|
+
w[i00] = w[i00] * keep - alpha * g[i00];
|
|
10455
|
+
}
|
|
10456
|
+
}
|
|
10457
|
+
}
|
|
10458
|
+
|
|
10459
|
+
void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
|
|
10460
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
10461
|
+
|
|
10462
|
+
switch (src0->type) {
|
|
10463
|
+
case GGML_TYPE_F32:
|
|
10464
|
+
{
|
|
10465
|
+
ggml_compute_forward_opt_step_sgd_f32(params, dst);
|
|
10466
|
+
}
|
|
10467
|
+
break;
|
|
10468
|
+
default:
|
|
10469
|
+
{
|
|
10470
|
+
GGML_ABORT("fatal error - sgd is F32 only");
|
|
10471
|
+
}
|
|
10472
|
+
}
|
|
10473
|
+
}
|