whispercpp 1.3.4 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +158 -44
- data/ext/extconf.rb +3 -2
- data/ext/ruby_whisper.c +34 -6
- data/ext/ruby_whisper.h +67 -0
- data/ext/ruby_whisper_context.c +236 -144
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +12 -13
- data/ext/ruby_whisper_params.c +47 -24
- data/ext/ruby_whisper_segment.c +84 -20
- data/ext/ruby_whisper_token.c +371 -0
- data/ext/ruby_whisper_transcribe.cpp +5 -2
- data/ext/ruby_whisper_vad_context.c +122 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +138 -0
- data/ext/ruby_whisper_vad_segments.c +105 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +129 -112
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +28 -15
- data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
- data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
- data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
- data/ext/sources/examples/talk-llama/llama-context.h +70 -23
- data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
- data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
- data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
- data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
- data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
- data/ext/sources/examples/talk-llama/llama-model.h +112 -18
- data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
- data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
- data/ext/sources/examples/talk-llama/llama.cpp +802 -21
- data/ext/sources/examples/talk-llama/llama.h +210 -39
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +704 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +90 -56
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +5 -2
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +6 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +14 -12
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +246 -21
- data/ext/sources/ggml/src/CMakeLists.txt +85 -11
- data/ext/sources/ggml/src/ggml-alloc.c +128 -50
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
- data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
- data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
- data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
- data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
- data/ext/sources/ggml/src/ggml-impl.h +129 -6
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
- data/ext/sources/ggml/src/ggml.c +590 -64
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +106 -62
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +162 -4
- data/test/test_context_params.rb +82 -0
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +81 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +100 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +27 -0
- data/whispercpp.gemspec +1 -1
- metadata +502 -37
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -17,10 +17,12 @@ struct ggml_metal_device_deleter {
|
|
|
17
17
|
|
|
18
18
|
typedef std::unique_ptr<ggml_metal_device, ggml_metal_device_deleter> ggml_metal_device_ptr;
|
|
19
19
|
|
|
20
|
-
ggml_metal_device_t ggml_metal_device_get(
|
|
21
|
-
static ggml_metal_device_ptr
|
|
20
|
+
ggml_metal_device_t ggml_metal_device_get(int device) {
|
|
21
|
+
static std::vector<ggml_metal_device_ptr> devs;
|
|
22
22
|
|
|
23
|
-
|
|
23
|
+
devs.emplace_back(ggml_metal_device_init(device));
|
|
24
|
+
|
|
25
|
+
return devs.back().get();
|
|
24
26
|
}
|
|
25
27
|
|
|
26
28
|
struct ggml_metal_pipelines {
|
|
@@ -50,14 +52,14 @@ void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, gg
|
|
|
50
52
|
}
|
|
51
53
|
|
|
52
54
|
ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) {
|
|
53
|
-
if
|
|
55
|
+
if (ppls->data.find(name) == ppls->data.end()) {
|
|
54
56
|
return nullptr;
|
|
55
57
|
}
|
|
56
58
|
|
|
57
59
|
return ppls->data[name];
|
|
58
60
|
}
|
|
59
61
|
|
|
60
|
-
|
|
62
|
+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) {
|
|
61
63
|
char base[256];
|
|
62
64
|
char name[256];
|
|
63
65
|
|
|
@@ -71,34 +73,55 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base(ggml_metal_library_t
|
|
|
71
73
|
snprintf(base, 256, "kernel_%s", op_str);
|
|
72
74
|
snprintf(name, 256, "%s", base);
|
|
73
75
|
|
|
74
|
-
|
|
75
|
-
if (res) {
|
|
76
|
-
|
|
76
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
77
|
+
if (!res.pipeline) {
|
|
78
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
77
79
|
}
|
|
78
80
|
|
|
79
|
-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
80
|
-
|
|
81
81
|
return res;
|
|
82
82
|
}
|
|
83
83
|
|
|
84
|
-
|
|
84
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) {
|
|
85
85
|
char base[256];
|
|
86
86
|
char name[256];
|
|
87
87
|
|
|
88
88
|
snprintf(base, 256, "kernel_cpy_%s_%s", ggml_type_name(tsrc), ggml_type_name(tdst));
|
|
89
89
|
snprintf(name, 256, "%s", base);
|
|
90
90
|
|
|
91
|
-
|
|
92
|
-
if (res) {
|
|
93
|
-
|
|
91
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
92
|
+
if (!res.pipeline) {
|
|
93
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
94
94
|
}
|
|
95
95
|
|
|
96
|
-
res
|
|
96
|
+
return res;
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
|
|
100
|
+
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
101
|
+
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
|
|
102
|
+
|
|
103
|
+
const char * pool_str = "undefined";
|
|
104
|
+
switch (op_pool) {
|
|
105
|
+
case GGML_OP_POOL_AVG: pool_str = "avg"; break;
|
|
106
|
+
case GGML_OP_POOL_MAX: pool_str = "max"; break;
|
|
107
|
+
default: GGML_ASSERT(false && "not implemented");
|
|
108
|
+
};
|
|
109
|
+
|
|
110
|
+
char base[256];
|
|
111
|
+
char name[256];
|
|
112
|
+
|
|
113
|
+
snprintf(base, sizeof(base), "kernel_pool_1d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
|
|
114
|
+
snprintf(name, sizeof(name), "%s", base);
|
|
115
|
+
|
|
116
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
117
|
+
if (!res.pipeline) {
|
|
118
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
119
|
+
}
|
|
97
120
|
|
|
98
121
|
return res;
|
|
99
122
|
}
|
|
100
123
|
|
|
101
|
-
|
|
124
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
|
|
102
125
|
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
103
126
|
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
|
|
104
127
|
|
|
@@ -115,126 +138,147 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library
|
|
|
115
138
|
snprintf(base, 256, "kernel_pool_2d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
|
|
116
139
|
snprintf(name, 256, "%s", base);
|
|
117
140
|
|
|
118
|
-
|
|
119
|
-
if (res) {
|
|
120
|
-
|
|
141
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
142
|
+
if (!res.pipeline) {
|
|
143
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
121
144
|
}
|
|
122
145
|
|
|
123
|
-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
124
|
-
|
|
125
146
|
return res;
|
|
126
147
|
}
|
|
127
148
|
|
|
128
|
-
|
|
149
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) {
|
|
129
150
|
char base[256];
|
|
130
151
|
char name[256];
|
|
131
152
|
|
|
132
153
|
snprintf(base, 256, "kernel_get_rows_%s", ggml_type_name(tsrc));
|
|
133
154
|
snprintf(name, 256, "%s", base);
|
|
134
155
|
|
|
135
|
-
|
|
136
|
-
if (res) {
|
|
137
|
-
|
|
156
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
157
|
+
if (!res.pipeline) {
|
|
158
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
138
159
|
}
|
|
139
160
|
|
|
140
|
-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
141
|
-
|
|
142
161
|
return res;
|
|
143
162
|
}
|
|
144
163
|
|
|
145
|
-
|
|
164
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) {
|
|
146
165
|
char base[256];
|
|
147
166
|
char name[256];
|
|
148
167
|
|
|
149
168
|
snprintf(base, 256, "kernel_set_rows_%s_%s", ggml_type_name(tdst), ggml_type_name(tidx));
|
|
150
169
|
snprintf(name, 256, "%s", base);
|
|
151
170
|
|
|
152
|
-
|
|
153
|
-
if (res) {
|
|
154
|
-
|
|
171
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
172
|
+
if (!res.pipeline) {
|
|
173
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
155
174
|
}
|
|
156
175
|
|
|
157
|
-
res
|
|
176
|
+
return res;
|
|
177
|
+
}
|
|
178
|
+
|
|
179
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
180
|
+
char base[256];
|
|
181
|
+
char name[256];
|
|
182
|
+
|
|
183
|
+
const int n = op->src[0]->ne[0];
|
|
184
|
+
|
|
185
|
+
snprintf(base, 256, "kernel_diag_%s", ggml_type_name(op->src[0]->type));
|
|
186
|
+
snprintf(name, 256, "%s_n=%d", base, n);
|
|
187
|
+
|
|
188
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
189
|
+
if (!res.pipeline) {
|
|
190
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
191
|
+
}
|
|
192
|
+
|
|
193
|
+
res.nsg = 1;
|
|
194
|
+
res.smem = 0;
|
|
158
195
|
|
|
159
196
|
return res;
|
|
160
197
|
}
|
|
161
198
|
|
|
162
|
-
|
|
199
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) {
|
|
163
200
|
char base[256];
|
|
164
201
|
char name[256];
|
|
165
202
|
|
|
166
203
|
snprintf(base, 256, "kernel_repeat_%s", ggml_type_name(tsrc));
|
|
167
204
|
snprintf(name, 256, "%s", base);
|
|
168
205
|
|
|
169
|
-
|
|
170
|
-
if (res) {
|
|
171
|
-
|
|
206
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
207
|
+
if (!res.pipeline) {
|
|
208
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
172
209
|
}
|
|
173
210
|
|
|
174
|
-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
175
|
-
|
|
176
211
|
return res;
|
|
177
212
|
}
|
|
178
213
|
|
|
179
|
-
|
|
180
|
-
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
181
|
-
|
|
214
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
182
215
|
char base[256];
|
|
183
216
|
char name[256];
|
|
184
217
|
|
|
185
|
-
|
|
218
|
+
int op_num = -1;
|
|
186
219
|
|
|
187
|
-
const char * op_str = "undefined";
|
|
188
220
|
switch (op->op) {
|
|
189
|
-
case GGML_OP_SCALE:
|
|
190
|
-
case
|
|
191
|
-
case
|
|
192
|
-
case
|
|
193
|
-
case
|
|
194
|
-
case
|
|
195
|
-
case
|
|
196
|
-
case
|
|
221
|
+
case GGML_OP_SCALE: op_num = OP_UNARY_NUM_SCALE; break;
|
|
222
|
+
case GGML_OP_FILL: op_num = OP_UNARY_NUM_FILL; break;
|
|
223
|
+
case GGML_OP_CLAMP: op_num = OP_UNARY_NUM_CLAMP; break;
|
|
224
|
+
case GGML_OP_SQR: op_num = OP_UNARY_NUM_SQR; break;
|
|
225
|
+
case GGML_OP_SQRT: op_num = OP_UNARY_NUM_SQRT; break;
|
|
226
|
+
case GGML_OP_SIN: op_num = OP_UNARY_NUM_SIN; break;
|
|
227
|
+
case GGML_OP_COS: op_num = OP_UNARY_NUM_COS; break;
|
|
228
|
+
case GGML_OP_LOG: op_num = OP_UNARY_NUM_LOG; break;
|
|
229
|
+
case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break;
|
|
197
230
|
case GGML_OP_UNARY:
|
|
198
231
|
switch (ggml_get_unary_op(op)) {
|
|
199
|
-
case GGML_UNARY_OP_TANH:
|
|
200
|
-
case GGML_UNARY_OP_RELU:
|
|
201
|
-
case GGML_UNARY_OP_SIGMOID:
|
|
202
|
-
case GGML_UNARY_OP_GELU:
|
|
203
|
-
case GGML_UNARY_OP_GELU_ERF:
|
|
204
|
-
case GGML_UNARY_OP_GELU_QUICK:
|
|
205
|
-
case GGML_UNARY_OP_SILU:
|
|
206
|
-
case GGML_UNARY_OP_ELU:
|
|
207
|
-
case GGML_UNARY_OP_NEG:
|
|
208
|
-
case GGML_UNARY_OP_ABS:
|
|
209
|
-
case GGML_UNARY_OP_SGN:
|
|
210
|
-
case GGML_UNARY_OP_STEP:
|
|
211
|
-
case GGML_UNARY_OP_HARDSWISH:
|
|
212
|
-
case GGML_UNARY_OP_HARDSIGMOID:
|
|
213
|
-
case GGML_UNARY_OP_EXP:
|
|
232
|
+
case GGML_UNARY_OP_TANH: op_num = OP_UNARY_NUM_TANH; break;
|
|
233
|
+
case GGML_UNARY_OP_RELU: op_num = OP_UNARY_NUM_RELU; break;
|
|
234
|
+
case GGML_UNARY_OP_SIGMOID: op_num = OP_UNARY_NUM_SIGMOID; break;
|
|
235
|
+
case GGML_UNARY_OP_GELU: op_num = OP_UNARY_NUM_GELU; break;
|
|
236
|
+
case GGML_UNARY_OP_GELU_ERF: op_num = OP_UNARY_NUM_GELU_ERF; break;
|
|
237
|
+
case GGML_UNARY_OP_GELU_QUICK: op_num = OP_UNARY_NUM_GELU_QUICK; break;
|
|
238
|
+
case GGML_UNARY_OP_SILU: op_num = OP_UNARY_NUM_SILU; break;
|
|
239
|
+
case GGML_UNARY_OP_ELU: op_num = OP_UNARY_NUM_ELU; break;
|
|
240
|
+
case GGML_UNARY_OP_NEG: op_num = OP_UNARY_NUM_NEG; break;
|
|
241
|
+
case GGML_UNARY_OP_ABS: op_num = OP_UNARY_NUM_ABS; break;
|
|
242
|
+
case GGML_UNARY_OP_SGN: op_num = OP_UNARY_NUM_SGN; break;
|
|
243
|
+
case GGML_UNARY_OP_STEP: op_num = OP_UNARY_NUM_STEP; break;
|
|
244
|
+
case GGML_UNARY_OP_HARDSWISH: op_num = OP_UNARY_NUM_HARDSWISH; break;
|
|
245
|
+
case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break;
|
|
246
|
+
case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break;
|
|
247
|
+
case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break;
|
|
248
|
+
case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break;
|
|
214
249
|
default: GGML_ABORT("fatal error");
|
|
215
250
|
} break;
|
|
216
251
|
default: GGML_ABORT("fatal error");
|
|
217
252
|
};
|
|
218
253
|
|
|
219
|
-
const char *
|
|
220
|
-
|
|
221
|
-
suffix = "_4";
|
|
222
|
-
}
|
|
254
|
+
const char * t0_str = ggml_type_name(op->src[0]->type);
|
|
255
|
+
const char * t_str = ggml_type_name(op->type);
|
|
223
256
|
|
|
224
|
-
|
|
225
|
-
|
|
257
|
+
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
|
|
258
|
+
const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768;
|
|
226
259
|
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
260
|
+
snprintf(base, 256, "kernel_unary_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
|
|
261
|
+
snprintf(name, 256, "%s_op=%d_cnt=%d", base, op_num, is_cnt);
|
|
262
|
+
|
|
263
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
264
|
+
if (!res.pipeline) {
|
|
265
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
266
|
+
|
|
267
|
+
ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0);
|
|
268
|
+
ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1);
|
|
269
|
+
|
|
270
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
271
|
+
|
|
272
|
+
ggml_metal_cv_free(cv);
|
|
230
273
|
}
|
|
231
274
|
|
|
232
|
-
res
|
|
275
|
+
res.c4 = is_c4;
|
|
276
|
+
res.cnt = is_cnt;
|
|
233
277
|
|
|
234
278
|
return res;
|
|
235
279
|
}
|
|
236
280
|
|
|
237
|
-
|
|
281
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
238
282
|
GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
|
|
239
283
|
|
|
240
284
|
char base[256];
|
|
@@ -258,48 +302,132 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l
|
|
|
258
302
|
snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
|
|
259
303
|
snprintf(name, 256, "%s", base);
|
|
260
304
|
|
|
261
|
-
|
|
262
|
-
if (res) {
|
|
263
|
-
|
|
305
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
306
|
+
if (!res.pipeline) {
|
|
307
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
264
308
|
}
|
|
265
309
|
|
|
266
|
-
res
|
|
310
|
+
return res;
|
|
311
|
+
}
|
|
312
|
+
|
|
313
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
314
|
+
assert(op->op == GGML_OP_SUM);
|
|
315
|
+
|
|
316
|
+
char base[256];
|
|
317
|
+
char name[256];
|
|
318
|
+
|
|
319
|
+
snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type));
|
|
320
|
+
snprintf(name, 256, "%s", base);
|
|
321
|
+
|
|
322
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
323
|
+
if (!res.pipeline) {
|
|
324
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
325
|
+
}
|
|
267
326
|
|
|
268
327
|
return res;
|
|
269
328
|
}
|
|
270
329
|
|
|
271
|
-
|
|
272
|
-
GGML_ASSERT(
|
|
330
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
331
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
273
332
|
|
|
274
333
|
char base[256];
|
|
275
334
|
char name[256];
|
|
276
335
|
|
|
277
|
-
|
|
336
|
+
int op_num = -1;
|
|
337
|
+
|
|
278
338
|
switch (op->op) {
|
|
279
|
-
case GGML_OP_SUM_ROWS:
|
|
280
|
-
|
|
281
|
-
case GGML_OP_MEAN:
|
|
282
|
-
op_str = "mean"; break;
|
|
339
|
+
case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break;
|
|
340
|
+
case GGML_OP_MEAN: op_num = OP_SUM_ROWS_NUM_MEAN; break;
|
|
283
341
|
default: GGML_ABORT("fatal error");
|
|
284
342
|
};
|
|
285
343
|
|
|
286
|
-
|
|
344
|
+
const char * t0_str = ggml_type_name(op->src[0]->type);
|
|
345
|
+
const char * t_str = ggml_type_name(op->type);
|
|
346
|
+
|
|
347
|
+
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
|
|
348
|
+
|
|
349
|
+
snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
|
|
350
|
+
snprintf(name, 256, "%s_op=%d", base, op_num);
|
|
351
|
+
|
|
352
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
353
|
+
if (!res.pipeline) {
|
|
354
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
355
|
+
|
|
356
|
+
ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0);
|
|
357
|
+
|
|
358
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
359
|
+
|
|
360
|
+
ggml_metal_cv_free(cv);
|
|
361
|
+
}
|
|
362
|
+
|
|
363
|
+
res.smem = 32*sizeof(float);
|
|
364
|
+
|
|
365
|
+
if (is_c4) {
|
|
366
|
+
res.smem *= 4;
|
|
367
|
+
}
|
|
368
|
+
|
|
369
|
+
res.c4 = is_c4;
|
|
370
|
+
|
|
371
|
+
return res;
|
|
372
|
+
}
|
|
373
|
+
|
|
374
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
375
|
+
GGML_ASSERT(op->op == GGML_OP_CUMSUM);
|
|
376
|
+
|
|
377
|
+
char base[256];
|
|
378
|
+
char name[256];
|
|
287
379
|
|
|
380
|
+
snprintf(base, 256, "kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type));
|
|
288
381
|
snprintf(name, 256, "%s", base);
|
|
289
382
|
|
|
290
|
-
|
|
291
|
-
if (res) {
|
|
292
|
-
|
|
383
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
384
|
+
if (!res.pipeline) {
|
|
385
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
293
386
|
}
|
|
294
387
|
|
|
295
|
-
res
|
|
388
|
+
return res;
|
|
389
|
+
}
|
|
390
|
+
|
|
391
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
392
|
+
GGML_ASSERT(op->op == GGML_OP_CUMSUM);
|
|
393
|
+
|
|
394
|
+
char base[256];
|
|
395
|
+
char name[256];
|
|
396
|
+
|
|
397
|
+
snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type));
|
|
398
|
+
snprintf(name, 256, "%s", base);
|
|
399
|
+
|
|
400
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
401
|
+
if (!res.pipeline) {
|
|
402
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
403
|
+
}
|
|
404
|
+
|
|
405
|
+
return res;
|
|
406
|
+
}
|
|
407
|
+
|
|
408
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
409
|
+
GGML_ASSERT(op->op == GGML_OP_TRI);
|
|
410
|
+
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
|
|
411
|
+
|
|
412
|
+
char base[256];
|
|
413
|
+
char name[256];
|
|
414
|
+
|
|
415
|
+
const char * op_str = "tri";
|
|
416
|
+
const int ttype = op->op_params[0];
|
|
296
417
|
|
|
297
|
-
|
|
418
|
+
snprintf(base, 256, "kernel_%s_%s_%d", op_str, ggml_type_name(op->src[0]->type), ttype);
|
|
419
|
+
|
|
420
|
+
snprintf(name, 256, "%s", base);
|
|
421
|
+
|
|
422
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
423
|
+
if (!res.pipeline) {
|
|
424
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
425
|
+
}
|
|
298
426
|
|
|
299
427
|
return res;
|
|
300
428
|
}
|
|
301
429
|
|
|
302
|
-
|
|
430
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
303
431
|
GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);
|
|
304
432
|
|
|
305
433
|
char base[256];
|
|
@@ -316,19 +444,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_librar
|
|
|
316
444
|
snprintf(base, 256, "kernel_soft_max_%s%s", ggml_type_name(tsrc1), suffix);
|
|
317
445
|
snprintf(name, 256, "%s", base);
|
|
318
446
|
|
|
319
|
-
|
|
320
|
-
if (res) {
|
|
321
|
-
|
|
447
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
448
|
+
if (!res.pipeline) {
|
|
449
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
322
450
|
}
|
|
323
451
|
|
|
324
|
-
res =
|
|
325
|
-
|
|
326
|
-
ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
|
|
452
|
+
res.smem = 32*sizeof(float);
|
|
327
453
|
|
|
328
454
|
return res;
|
|
329
455
|
}
|
|
330
456
|
|
|
331
|
-
|
|
457
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
332
458
|
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
|
|
333
459
|
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
|
334
460
|
|
|
@@ -338,43 +464,82 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
|
|
|
338
464
|
char base[256];
|
|
339
465
|
char name[256];
|
|
340
466
|
|
|
341
|
-
|
|
342
|
-
snprintf(name, 256, "%s", base);
|
|
467
|
+
const char * suffix = "";
|
|
343
468
|
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
return res;
|
|
469
|
+
if (op->src[1]->ne[0] % 4 == 0) {
|
|
470
|
+
suffix = "_4";
|
|
347
471
|
}
|
|
348
472
|
|
|
349
|
-
|
|
473
|
+
snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
|
|
474
|
+
snprintf(name, 256, "%s", base);
|
|
475
|
+
|
|
476
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
477
|
+
if (!res.pipeline) {
|
|
478
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
479
|
+
}
|
|
350
480
|
|
|
351
481
|
return res;
|
|
352
482
|
}
|
|
353
483
|
|
|
354
|
-
|
|
484
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const ggml_tensor * op, int ssm_conv_bs) {
|
|
485
|
+
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
|
|
486
|
+
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
|
487
|
+
|
|
488
|
+
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
489
|
+
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
|
|
490
|
+
|
|
355
491
|
char base[256];
|
|
356
492
|
char name[256];
|
|
357
493
|
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
|
|
494
|
+
const char * suffix = "";
|
|
495
|
+
if (op->src[1]->ne[0] % 4 == 0) {
|
|
496
|
+
suffix = "_4";
|
|
362
497
|
}
|
|
363
|
-
snprintf(name, 256, "%s", base);
|
|
364
498
|
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
499
|
+
snprintf(base, 256, "kernel_ssm_conv_%s_%s_batched%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
|
|
500
|
+
snprintf(name, 256, "%s_ssm_conv_bs=%d", base, ssm_conv_bs);
|
|
501
|
+
|
|
502
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
503
|
+
if (!res.pipeline) {
|
|
504
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
505
|
+
|
|
506
|
+
ggml_metal_cv_set_int16(cv, ssm_conv_bs, FC_SSM_CONV + 0);
|
|
507
|
+
|
|
508
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
509
|
+
|
|
510
|
+
ggml_metal_cv_free(cv);
|
|
368
511
|
}
|
|
369
512
|
|
|
370
|
-
res
|
|
513
|
+
return res;
|
|
514
|
+
}
|
|
515
|
+
|
|
516
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
517
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
518
|
+
|
|
519
|
+
char base[256];
|
|
520
|
+
char name[256];
|
|
371
521
|
|
|
372
|
-
|
|
522
|
+
const int nsg = (ne00 + 31)/32;
|
|
523
|
+
|
|
524
|
+
snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
|
|
525
|
+
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
|
526
|
+
|
|
527
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
528
|
+
if (!res.pipeline) {
|
|
529
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
530
|
+
}
|
|
531
|
+
|
|
532
|
+
// Shared memory layout:
|
|
533
|
+
// - sgptg * NW floats for partial sums (nsg * 32)
|
|
534
|
+
// - sgptg floats for shared_x_dt (nsg)
|
|
535
|
+
// - sgptg floats for shared_dA (nsg)
|
|
536
|
+
// Total: nsg * (32 + 2) floats
|
|
537
|
+
res.smem = (32 + 2)*sizeof(float)*nsg;
|
|
373
538
|
|
|
374
539
|
return res;
|
|
375
540
|
}
|
|
376
541
|
|
|
377
|
-
|
|
542
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
378
543
|
char base[256];
|
|
379
544
|
char name[256];
|
|
380
545
|
|
|
@@ -404,41 +569,102 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t
|
|
|
404
569
|
|
|
405
570
|
snprintf(name, 256, "%s", base);
|
|
406
571
|
|
|
407
|
-
|
|
408
|
-
if (res) {
|
|
409
|
-
|
|
572
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
573
|
+
if (!res.pipeline) {
|
|
574
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
410
575
|
}
|
|
411
576
|
|
|
412
|
-
res
|
|
577
|
+
return res;
|
|
578
|
+
}
|
|
579
|
+
|
|
580
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
581
|
+
char base[256];
|
|
582
|
+
char name[256];
|
|
583
|
+
|
|
584
|
+
// v is src[2], dimensions: S_v = ne[0], H = ne[1]
|
|
585
|
+
const int ne20 = op->src[2]->ne[0]; // S_v
|
|
586
|
+
const int ne21 = op->src[2]->ne[1]; // H
|
|
587
|
+
const int ne30 = op->src[3]->ne[0]; // G
|
|
588
|
+
|
|
589
|
+
const int nsg = op->src[2]->ne[0]/32;
|
|
590
|
+
|
|
591
|
+
GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32);
|
|
592
|
+
GGML_ASSERT(op->ne[0] == ne20 * ne21);
|
|
593
|
+
GGML_ASSERT(ne20 % 32 == 0);
|
|
594
|
+
|
|
595
|
+
snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg);
|
|
596
|
+
snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30);
|
|
597
|
+
|
|
598
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
599
|
+
if (!res.pipeline) {
|
|
600
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
601
|
+
|
|
602
|
+
ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0);
|
|
603
|
+
ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1);
|
|
604
|
+
|
|
605
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
606
|
+
|
|
607
|
+
ggml_metal_cv_free(cv);
|
|
608
|
+
}
|
|
609
|
+
|
|
610
|
+
res.nsg = nsg;
|
|
413
611
|
|
|
414
612
|
return res;
|
|
415
613
|
}
|
|
416
614
|
|
|
417
|
-
|
|
615
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
418
616
|
char base[256];
|
|
419
617
|
char name[256];
|
|
420
618
|
|
|
421
|
-
|
|
422
|
-
|
|
619
|
+
const int nsg = 8;
|
|
620
|
+
const int n = op->src[1]->ne[1];
|
|
621
|
+
const int k = op->src[1]->ne[0];
|
|
423
622
|
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
623
|
+
snprintf(base, 256, "kernel_solve_tri_%s", ggml_type_name(op->src[0]->type));
|
|
624
|
+
snprintf(name, 256, "%s_nsg=%d_n=%d_k=%d", base, nsg, n, k);
|
|
625
|
+
|
|
626
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
627
|
+
if (!res.pipeline) {
|
|
628
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
629
|
+
|
|
630
|
+
ggml_metal_cv_set_int16(cv, nsg, FC_SOLVE_TRI + 0);
|
|
631
|
+
ggml_metal_cv_set_int16(cv, n, FC_SOLVE_TRI + 1);
|
|
632
|
+
ggml_metal_cv_set_int16(cv, k, FC_SOLVE_TRI + 2);
|
|
633
|
+
|
|
634
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
635
|
+
|
|
636
|
+
ggml_metal_cv_free(cv);
|
|
427
637
|
}
|
|
428
638
|
|
|
429
|
-
|
|
639
|
+
res.nsg = nsg;
|
|
640
|
+
res.smem = GGML_PAD(GGML_PAD(n, 32)*nsg*sizeof(float), 16);
|
|
430
641
|
|
|
431
|
-
|
|
432
|
-
|
|
642
|
+
return res;
|
|
643
|
+
}
|
|
644
|
+
|
|
645
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
|
|
646
|
+
char base[256];
|
|
647
|
+
char name[256];
|
|
648
|
+
|
|
649
|
+
snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
|
|
650
|
+
snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);
|
|
651
|
+
|
|
652
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
653
|
+
if (!res.pipeline) {
|
|
654
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
433
655
|
|
|
434
|
-
|
|
656
|
+
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
|
|
657
|
+
ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
|
|
435
658
|
|
|
436
|
-
|
|
659
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
660
|
+
|
|
661
|
+
ggml_metal_cv_free(cv);
|
|
662
|
+
}
|
|
437
663
|
|
|
438
664
|
return res;
|
|
439
665
|
}
|
|
440
666
|
|
|
441
|
-
|
|
667
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
442
668
|
char base[256];
|
|
443
669
|
char name[256];
|
|
444
670
|
|
|
@@ -451,27 +677,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_
|
|
|
451
677
|
snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
|
|
452
678
|
snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
|
|
453
679
|
|
|
454
|
-
|
|
455
|
-
if (res) {
|
|
456
|
-
|
|
457
|
-
}
|
|
680
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
681
|
+
if (!res.pipeline) {
|
|
682
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
458
683
|
|
|
459
|
-
|
|
684
|
+
ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
|
|
685
|
+
ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
|
|
460
686
|
|
|
461
|
-
|
|
462
|
-
ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
|
|
687
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
463
688
|
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
ggml_metal_cv_free(cv);
|
|
689
|
+
ggml_metal_cv_free(cv);
|
|
690
|
+
}
|
|
467
691
|
|
|
468
692
|
// when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes
|
|
469
|
-
|
|
693
|
+
res.smem = bc_out ? 8192 : 4096 + 2048;
|
|
470
694
|
|
|
471
695
|
return res;
|
|
472
696
|
}
|
|
473
697
|
|
|
474
|
-
|
|
698
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
475
699
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
476
700
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
477
701
|
|
|
@@ -626,49 +850,43 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
|
|
|
626
850
|
snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
|
|
627
851
|
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
|
628
852
|
|
|
629
|
-
|
|
630
|
-
if (res) {
|
|
631
|
-
|
|
632
|
-
}
|
|
633
|
-
|
|
634
|
-
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
853
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
854
|
+
if (!res.pipeline) {
|
|
855
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
635
856
|
|
|
636
|
-
|
|
857
|
+
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
|
|
637
858
|
|
|
638
|
-
|
|
859
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
639
860
|
|
|
640
|
-
|
|
861
|
+
ggml_metal_cv_free(cv);
|
|
862
|
+
}
|
|
641
863
|
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
864
|
+
res.nr0 = nr0;
|
|
865
|
+
res.nr1 = nr1;
|
|
866
|
+
res.nsg = nsg;
|
|
867
|
+
res.smem = smem;
|
|
646
868
|
|
|
647
869
|
return res;
|
|
648
870
|
}
|
|
649
871
|
|
|
650
|
-
|
|
872
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) {
|
|
651
873
|
char base[256];
|
|
652
874
|
char name[256];
|
|
653
875
|
|
|
654
876
|
snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
|
|
655
|
-
snprintf(name, 256, "%
|
|
877
|
+
snprintf(name, 256, "%s_ne02=%d", base, ne02);
|
|
656
878
|
|
|
657
|
-
|
|
658
|
-
if (res) {
|
|
659
|
-
|
|
879
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
880
|
+
if (!res.pipeline) {
|
|
881
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
660
882
|
}
|
|
661
883
|
|
|
662
|
-
res =
|
|
663
|
-
|
|
664
|
-
const size_t smem = (size_t) ne02*ne20*sizeof(uint16_t);
|
|
665
|
-
|
|
666
|
-
ggml_metal_pipeline_set_smem(res, smem);
|
|
884
|
+
res.smem = (size_t) ne02*ne20*sizeof(uint16_t);
|
|
667
885
|
|
|
668
886
|
return res;
|
|
669
887
|
}
|
|
670
888
|
|
|
671
|
-
|
|
889
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
672
890
|
char base[256];
|
|
673
891
|
char name[256];
|
|
674
892
|
|
|
@@ -680,25 +898,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_libra
|
|
|
680
898
|
snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
|
|
681
899
|
snprintf(name, 256, "%s_bci=%d", base, bc_inp);
|
|
682
900
|
|
|
683
|
-
|
|
684
|
-
if (res) {
|
|
685
|
-
|
|
686
|
-
}
|
|
901
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
902
|
+
if (!res.pipeline) {
|
|
903
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
687
904
|
|
|
688
|
-
|
|
905
|
+
ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
|
|
689
906
|
|
|
690
|
-
|
|
907
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
691
908
|
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
ggml_metal_cv_free(cv);
|
|
909
|
+
ggml_metal_cv_free(cv);
|
|
910
|
+
}
|
|
695
911
|
|
|
696
|
-
|
|
912
|
+
res.smem = 8192;
|
|
697
913
|
|
|
698
914
|
return res;
|
|
699
915
|
}
|
|
700
916
|
|
|
701
|
-
|
|
917
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
702
918
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
703
919
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
704
920
|
|
|
@@ -846,28 +1062,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
|
|
|
846
1062
|
snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
|
|
847
1063
|
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
|
848
1064
|
|
|
849
|
-
|
|
850
|
-
if (res) {
|
|
851
|
-
|
|
852
|
-
}
|
|
853
|
-
|
|
854
|
-
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
1065
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1066
|
+
if (!res.pipeline) {
|
|
1067
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
855
1068
|
|
|
856
|
-
|
|
1069
|
+
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
|
|
857
1070
|
|
|
858
|
-
|
|
1071
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
859
1072
|
|
|
860
|
-
|
|
1073
|
+
ggml_metal_cv_free(cv);
|
|
1074
|
+
}
|
|
861
1075
|
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
1076
|
+
res.nr0 = nr0;
|
|
1077
|
+
res.nr1 = nr1;
|
|
1078
|
+
res.nsg = nsg;
|
|
1079
|
+
res.smem = smem;
|
|
866
1080
|
|
|
867
1081
|
return res;
|
|
868
1082
|
}
|
|
869
1083
|
|
|
870
|
-
|
|
1084
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
871
1085
|
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
|
|
872
1086
|
GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
|
|
873
1087
|
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
|
|
@@ -878,19 +1092,43 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax(ggml_metal_library_
|
|
|
878
1092
|
snprintf(base, 256, "kernel_argmax_%s", ggml_type_name(op->src[0]->type));
|
|
879
1093
|
snprintf(name, 256, "%s", base);
|
|
880
1094
|
|
|
881
|
-
|
|
882
|
-
if (res) {
|
|
883
|
-
|
|
1095
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1096
|
+
if (!res.pipeline) {
|
|
1097
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
884
1098
|
}
|
|
885
1099
|
|
|
886
|
-
res =
|
|
1100
|
+
res.smem = 32*(sizeof(float) + sizeof(int32_t));
|
|
1101
|
+
|
|
1102
|
+
return res;
|
|
1103
|
+
}
|
|
1104
|
+
|
|
1105
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1106
|
+
assert(op->op == GGML_OP_ARGSORT);
|
|
1107
|
+
|
|
1108
|
+
char base[256];
|
|
1109
|
+
char name[256];
|
|
1110
|
+
|
|
1111
|
+
ggml_sort_order order = (ggml_sort_order) op->op_params[0];
|
|
1112
|
+
|
|
1113
|
+
const char * order_str = "undefined";
|
|
1114
|
+
switch (order) {
|
|
1115
|
+
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
|
|
1116
|
+
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
|
|
1117
|
+
default: GGML_ABORT("fatal error");
|
|
1118
|
+
};
|
|
887
1119
|
|
|
888
|
-
|
|
1120
|
+
snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
|
|
1121
|
+
snprintf(name, 256, "%s", base);
|
|
1122
|
+
|
|
1123
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1124
|
+
if (!res.pipeline) {
|
|
1125
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1126
|
+
}
|
|
889
1127
|
|
|
890
1128
|
return res;
|
|
891
1129
|
}
|
|
892
1130
|
|
|
893
|
-
|
|
1131
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
894
1132
|
assert(op->op == GGML_OP_ARGSORT);
|
|
895
1133
|
|
|
896
1134
|
char base[256];
|
|
@@ -905,26 +1143,165 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
|
|
|
905
1143
|
default: GGML_ABORT("fatal error");
|
|
906
1144
|
};
|
|
907
1145
|
|
|
1146
|
+
snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
|
|
1147
|
+
snprintf(name, 256, "%s", base);
|
|
1148
|
+
|
|
1149
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1150
|
+
if (!res.pipeline) {
|
|
1151
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1152
|
+
}
|
|
1153
|
+
|
|
1154
|
+
return res;
|
|
1155
|
+
}
|
|
1156
|
+
|
|
1157
|
+
// note: reuse the argsort kernel for top_k
|
|
1158
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1159
|
+
assert(op->op == GGML_OP_TOP_K);
|
|
1160
|
+
|
|
1161
|
+
char base[256];
|
|
1162
|
+
char name[256];
|
|
1163
|
+
|
|
1164
|
+
// note: the top_k kernel is always descending order
|
|
1165
|
+
ggml_sort_order order = GGML_SORT_ORDER_DESC;
|
|
1166
|
+
|
|
1167
|
+
const char * order_str = "undefined";
|
|
1168
|
+
switch (order) {
|
|
1169
|
+
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
|
|
1170
|
+
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
|
|
1171
|
+
default: GGML_ABORT("fatal error");
|
|
1172
|
+
};
|
|
1173
|
+
|
|
908
1174
|
snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
|
|
909
1175
|
snprintf(name, 256, "%s", base);
|
|
910
1176
|
|
|
911
|
-
|
|
912
|
-
if (res) {
|
|
913
|
-
|
|
1177
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1178
|
+
if (!res.pipeline) {
|
|
1179
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
914
1180
|
}
|
|
915
1181
|
|
|
916
|
-
res
|
|
1182
|
+
return res;
|
|
1183
|
+
}
|
|
1184
|
+
|
|
1185
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1186
|
+
assert(op->op == GGML_OP_TOP_K);
|
|
1187
|
+
|
|
1188
|
+
char base[256];
|
|
1189
|
+
char name[256];
|
|
1190
|
+
|
|
1191
|
+
ggml_sort_order order = GGML_SORT_ORDER_DESC;
|
|
1192
|
+
|
|
1193
|
+
const char * order_str = "undefined";
|
|
1194
|
+
switch (order) {
|
|
1195
|
+
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
|
|
1196
|
+
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
|
|
1197
|
+
default: GGML_ABORT("fatal error");
|
|
1198
|
+
};
|
|
1199
|
+
|
|
1200
|
+
snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
|
|
1201
|
+
snprintf(name, 256, "%s", base);
|
|
1202
|
+
|
|
1203
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1204
|
+
if (!res.pipeline) {
|
|
1205
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1206
|
+
}
|
|
1207
|
+
|
|
1208
|
+
return res;
|
|
1209
|
+
}
|
|
1210
|
+
|
|
1211
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
|
1212
|
+
ggml_metal_library_t lib,
|
|
1213
|
+
const struct ggml_tensor * op,
|
|
1214
|
+
bool has_mask,
|
|
1215
|
+
int32_t ncpsg) {
|
|
1216
|
+
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
|
1217
|
+
GGML_UNUSED(op);
|
|
1218
|
+
|
|
1219
|
+
char base[256];
|
|
1220
|
+
char name[256];
|
|
1221
|
+
|
|
1222
|
+
snprintf(base, 256, "kernel_%s",
|
|
1223
|
+
"flash_attn_ext_pad");
|
|
1224
|
+
|
|
1225
|
+
snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
|
|
1226
|
+
base,
|
|
1227
|
+
has_mask,
|
|
1228
|
+
ncpsg);
|
|
1229
|
+
|
|
1230
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1231
|
+
if (!res.pipeline) {
|
|
1232
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
1233
|
+
|
|
1234
|
+
ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
|
|
1235
|
+
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
|
|
1236
|
+
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
|
|
1237
|
+
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
|
|
1238
|
+
|
|
1239
|
+
//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
|
|
1240
|
+
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
|
|
1241
|
+
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
|
|
1242
|
+
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
|
|
1243
|
+
//ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
|
|
1244
|
+
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
|
|
1245
|
+
|
|
1246
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1247
|
+
|
|
1248
|
+
ggml_metal_cv_free(cv);
|
|
1249
|
+
}
|
|
917
1250
|
|
|
918
1251
|
return res;
|
|
919
1252
|
}
|
|
920
1253
|
|
|
921
|
-
|
|
1254
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk(
|
|
1255
|
+
ggml_metal_library_t lib,
|
|
1256
|
+
const struct ggml_tensor * op,
|
|
1257
|
+
int32_t nqptg,
|
|
1258
|
+
int32_t ncpsg) {
|
|
1259
|
+
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
|
1260
|
+
GGML_UNUSED(op);
|
|
1261
|
+
|
|
1262
|
+
char base[256];
|
|
1263
|
+
char name[256];
|
|
1264
|
+
|
|
1265
|
+
snprintf(base, 256, "kernel_%s",
|
|
1266
|
+
"flash_attn_ext_blk");
|
|
1267
|
+
|
|
1268
|
+
snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d",
|
|
1269
|
+
base,
|
|
1270
|
+
nqptg,
|
|
1271
|
+
ncpsg);
|
|
1272
|
+
|
|
1273
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1274
|
+
if (!res.pipeline) {
|
|
1275
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
1276
|
+
|
|
1277
|
+
//ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0);
|
|
1278
|
+
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
|
|
1279
|
+
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2);
|
|
1280
|
+
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3);
|
|
1281
|
+
|
|
1282
|
+
//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
|
|
1283
|
+
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
|
|
1284
|
+
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
|
|
1285
|
+
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
|
|
1286
|
+
ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
|
|
1287
|
+
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
|
|
1288
|
+
|
|
1289
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1290
|
+
|
|
1291
|
+
ggml_metal_cv_free(cv);
|
|
1292
|
+
}
|
|
1293
|
+
|
|
1294
|
+
return res;
|
|
1295
|
+
}
|
|
1296
|
+
|
|
1297
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext(
|
|
922
1298
|
ggml_metal_library_t lib,
|
|
923
1299
|
const ggml_tensor * op,
|
|
924
1300
|
bool has_mask,
|
|
925
1301
|
bool has_sinks,
|
|
926
1302
|
bool has_bias,
|
|
927
1303
|
bool has_scap,
|
|
1304
|
+
bool has_kvpad,
|
|
928
1305
|
int32_t nsg) {
|
|
929
1306
|
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
|
930
1307
|
|
|
@@ -937,52 +1314,59 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
|
|
937
1314
|
const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
|
|
938
1315
|
const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
|
|
939
1316
|
|
|
1317
|
+
// do bounds checks for the mask?
|
|
1318
|
+
const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
|
|
1319
|
+
|
|
940
1320
|
snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
|
|
941
1321
|
"flash_attn_ext",
|
|
942
1322
|
ggml_type_name(op->src[1]->type),
|
|
943
1323
|
dk,
|
|
944
1324
|
dv);
|
|
945
1325
|
|
|
946
|
-
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
|
|
1326
|
+
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
|
|
947
1327
|
base,
|
|
948
1328
|
has_mask,
|
|
949
1329
|
has_sinks,
|
|
950
1330
|
has_bias,
|
|
951
1331
|
has_scap,
|
|
1332
|
+
has_kvpad,
|
|
1333
|
+
bc_mask,
|
|
952
1334
|
ns10,
|
|
953
1335
|
ns20,
|
|
954
1336
|
nsg);
|
|
955
1337
|
|
|
956
|
-
|
|
957
|
-
if (res) {
|
|
958
|
-
|
|
959
|
-
}
|
|
1338
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1339
|
+
if (!res.pipeline) {
|
|
1340
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
960
1341
|
|
|
961
|
-
|
|
1342
|
+
ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0);
|
|
1343
|
+
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
|
|
1344
|
+
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
|
|
1345
|
+
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
|
|
1346
|
+
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
|
|
962
1347
|
|
|
963
|
-
|
|
964
|
-
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
|
|
965
|
-
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
|
|
966
|
-
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
|
|
1348
|
+
ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
|
|
967
1349
|
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
1350
|
+
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
|
|
1351
|
+
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
|
|
1352
|
+
ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22);
|
|
971
1353
|
|
|
972
|
-
|
|
1354
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
973
1355
|
|
|
974
|
-
|
|
1356
|
+
ggml_metal_cv_free(cv);
|
|
1357
|
+
}
|
|
975
1358
|
|
|
976
1359
|
return res;
|
|
977
1360
|
}
|
|
978
1361
|
|
|
979
|
-
|
|
1362
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
|
980
1363
|
ggml_metal_library_t lib,
|
|
981
1364
|
const ggml_tensor * op,
|
|
982
1365
|
bool has_mask,
|
|
983
1366
|
bool has_sinks,
|
|
984
1367
|
bool has_bias,
|
|
985
1368
|
bool has_scap,
|
|
1369
|
+
bool has_kvpad,
|
|
986
1370
|
int32_t nsg,
|
|
987
1371
|
int32_t nwg) {
|
|
988
1372
|
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
|
@@ -1002,41 +1386,41 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
|
|
1002
1386
|
dk,
|
|
1003
1387
|
dv);
|
|
1004
1388
|
|
|
1005
|
-
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%
|
|
1389
|
+
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
|
|
1006
1390
|
base,
|
|
1007
1391
|
has_mask,
|
|
1008
1392
|
has_sinks,
|
|
1009
1393
|
has_bias,
|
|
1010
1394
|
has_scap,
|
|
1395
|
+
has_kvpad,
|
|
1011
1396
|
ns10,
|
|
1012
1397
|
ns20,
|
|
1013
1398
|
nsg, nwg);
|
|
1014
1399
|
|
|
1015
|
-
|
|
1016
|
-
if (res) {
|
|
1017
|
-
|
|
1018
|
-
}
|
|
1019
|
-
|
|
1020
|
-
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
1400
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1401
|
+
if (!res.pipeline) {
|
|
1402
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
1021
1403
|
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1404
|
+
ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0);
|
|
1405
|
+
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
|
|
1406
|
+
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
|
|
1407
|
+
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
|
|
1408
|
+
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
|
|
1026
1409
|
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
1410
|
+
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
|
|
1411
|
+
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
|
|
1412
|
+
ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22);
|
|
1413
|
+
ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23);
|
|
1031
1414
|
|
|
1032
|
-
|
|
1415
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1033
1416
|
|
|
1034
|
-
|
|
1417
|
+
ggml_metal_cv_free(cv);
|
|
1418
|
+
}
|
|
1035
1419
|
|
|
1036
1420
|
return res;
|
|
1037
1421
|
}
|
|
1038
1422
|
|
|
1039
|
-
|
|
1423
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
|
|
1040
1424
|
ggml_metal_library_t lib,
|
|
1041
1425
|
const ggml_tensor * op,
|
|
1042
1426
|
int32_t dv,
|
|
@@ -1049,85 +1433,128 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
|
|
|
1049
1433
|
snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
|
|
1050
1434
|
snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg);
|
|
1051
1435
|
|
|
1052
|
-
|
|
1053
|
-
if (res) {
|
|
1054
|
-
|
|
1055
|
-
}
|
|
1056
|
-
|
|
1057
|
-
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
1436
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1437
|
+
if (!res.pipeline) {
|
|
1438
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
1058
1439
|
|
|
1059
|
-
|
|
1060
|
-
|
|
1440
|
+
ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0);
|
|
1441
|
+
ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1);
|
|
1061
1442
|
|
|
1062
|
-
|
|
1443
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1063
1444
|
|
|
1064
|
-
|
|
1445
|
+
ggml_metal_cv_free(cv);
|
|
1446
|
+
}
|
|
1065
1447
|
|
|
1066
1448
|
return res;
|
|
1067
1449
|
|
|
1068
1450
|
GGML_UNUSED(op);
|
|
1069
1451
|
}
|
|
1070
1452
|
|
|
1071
|
-
|
|
1072
|
-
ggml_metal_library_t lib,
|
|
1073
|
-
ggml_op op,
|
|
1074
|
-
int32_t n_fuse,
|
|
1075
|
-
bool row) {
|
|
1453
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
|
|
1076
1454
|
char base[256];
|
|
1077
1455
|
char name[256];
|
|
1078
1456
|
|
|
1079
|
-
|
|
1080
|
-
|
|
1081
|
-
|
|
1082
|
-
case
|
|
1083
|
-
case
|
|
1084
|
-
case
|
|
1457
|
+
int op_num = -1;
|
|
1458
|
+
|
|
1459
|
+
switch (op->op) {
|
|
1460
|
+
case GGML_OP_ADD: op_num = 0; break;
|
|
1461
|
+
case GGML_OP_SUB: op_num = 1; break;
|
|
1462
|
+
case GGML_OP_MUL: op_num = 2; break;
|
|
1463
|
+
case GGML_OP_DIV: op_num = 3; break;
|
|
1085
1464
|
default: GGML_ABORT("fatal error");
|
|
1086
1465
|
};
|
|
1087
1466
|
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse);
|
|
1092
|
-
}
|
|
1467
|
+
const char * t0_str = ggml_type_name(op->src[0]->type);
|
|
1468
|
+
const char * t1_str = ggml_type_name(op->src[1]->type);
|
|
1469
|
+
const char * t_str = ggml_type_name(op->type);
|
|
1093
1470
|
|
|
1094
|
-
|
|
1471
|
+
const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0);
|
|
1095
1472
|
|
|
1096
|
-
|
|
1097
|
-
|
|
1098
|
-
|
|
1473
|
+
const bool is_cb = op->src[0]->ne[0] != op->src[1]->ne[0];
|
|
1474
|
+
const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536;
|
|
1475
|
+
|
|
1476
|
+
snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : "");
|
|
1477
|
+
snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d_cb=%d", base, op_num, n_fuse, is_rb, is_cb);
|
|
1478
|
+
|
|
1479
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1480
|
+
if (!res.pipeline) {
|
|
1481
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
1482
|
+
|
|
1483
|
+
ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
|
|
1484
|
+
ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1);
|
|
1485
|
+
ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2);
|
|
1486
|
+
ggml_metal_cv_set_bool (cv, is_cb, FC_BIN + 3);
|
|
1487
|
+
|
|
1488
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1489
|
+
|
|
1490
|
+
ggml_metal_cv_free(cv);
|
|
1099
1491
|
}
|
|
1100
1492
|
|
|
1101
|
-
res
|
|
1493
|
+
res.c4 = is_c4;
|
|
1494
|
+
res.cnt = is_rb;
|
|
1102
1495
|
|
|
1103
1496
|
return res;
|
|
1104
1497
|
}
|
|
1105
1498
|
|
|
1106
|
-
|
|
1107
|
-
|
|
1499
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) {
|
|
1500
|
+
char base[256];
|
|
1501
|
+
char name[256];
|
|
1108
1502
|
|
|
1109
|
-
|
|
1110
|
-
|
|
1503
|
+
int op_num = -1;
|
|
1504
|
+
|
|
1505
|
+
switch (op) {
|
|
1506
|
+
case GGML_OP_ADD: op_num = 0; break;
|
|
1507
|
+
case GGML_OP_SUB: op_num = 1; break;
|
|
1508
|
+
case GGML_OP_MUL: op_num = 2; break;
|
|
1509
|
+
case GGML_OP_DIV: op_num = 3; break;
|
|
1510
|
+
default: GGML_ABORT("fatal error");
|
|
1511
|
+
};
|
|
1512
|
+
|
|
1513
|
+
snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32");
|
|
1514
|
+
snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1);
|
|
1515
|
+
|
|
1516
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1517
|
+
if (!res.pipeline) {
|
|
1518
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
1519
|
+
|
|
1520
|
+
ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
|
|
1521
|
+
ggml_metal_cv_set_int16(cv, 1, FC_BIN + 1);
|
|
1522
|
+
ggml_metal_cv_set_bool (cv, false, FC_BIN + 2);
|
|
1523
|
+
|
|
1524
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1525
|
+
|
|
1526
|
+
ggml_metal_cv_free(cv);
|
|
1527
|
+
}
|
|
1528
|
+
|
|
1529
|
+
return res;
|
|
1530
|
+
}
|
|
1531
|
+
|
|
1532
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1533
|
+
assert(op->op == GGML_OP_L2_NORM);
|
|
1111
1534
|
|
|
1112
1535
|
char base[256];
|
|
1113
1536
|
char name[256];
|
|
1114
1537
|
|
|
1115
|
-
|
|
1538
|
+
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
|
|
1539
|
+
|
|
1540
|
+
const char * t0_str = ggml_type_name(op->src[0]->type);
|
|
1541
|
+
const char * t_str = ggml_type_name(op->type);
|
|
1542
|
+
|
|
1543
|
+
snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
|
|
1116
1544
|
snprintf(name, 256, "%s", base);
|
|
1117
1545
|
|
|
1118
|
-
|
|
1119
|
-
if (res) {
|
|
1120
|
-
|
|
1546
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1547
|
+
if (!res.pipeline) {
|
|
1548
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1121
1549
|
}
|
|
1122
1550
|
|
|
1123
|
-
res
|
|
1124
|
-
|
|
1125
|
-
ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
|
|
1551
|
+
res.c4 = is_c4;
|
|
1552
|
+
res.smem = 32*sizeof(float);
|
|
1126
1553
|
|
|
1127
1554
|
return res;
|
|
1128
1555
|
}
|
|
1129
1556
|
|
|
1130
|
-
|
|
1557
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1131
1558
|
assert(op->op == GGML_OP_GROUP_NORM);
|
|
1132
1559
|
|
|
1133
1560
|
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
@@ -1138,19 +1565,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_libr
|
|
|
1138
1565
|
snprintf(base, 256, "kernel_group_norm_f32");
|
|
1139
1566
|
snprintf(name, 256, "%s", base);
|
|
1140
1567
|
|
|
1141
|
-
|
|
1142
|
-
if (res) {
|
|
1143
|
-
|
|
1568
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1569
|
+
if (!res.pipeline) {
|
|
1570
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1144
1571
|
}
|
|
1145
1572
|
|
|
1146
|
-
res =
|
|
1147
|
-
|
|
1148
|
-
ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
|
|
1573
|
+
res.smem = 32*sizeof(float);
|
|
1149
1574
|
|
|
1150
1575
|
return res;
|
|
1151
1576
|
}
|
|
1152
1577
|
|
|
1153
|
-
|
|
1578
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) {
|
|
1154
1579
|
assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM);
|
|
1155
1580
|
|
|
1156
1581
|
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
@@ -1183,19 +1608,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t
|
|
|
1183
1608
|
|
|
1184
1609
|
snprintf(name, 256, "%s", base);
|
|
1185
1610
|
|
|
1186
|
-
|
|
1187
|
-
if (res) {
|
|
1188
|
-
|
|
1611
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1612
|
+
if (!res.pipeline) {
|
|
1613
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1189
1614
|
}
|
|
1190
1615
|
|
|
1191
|
-
res =
|
|
1192
|
-
|
|
1193
|
-
ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
|
|
1616
|
+
res.smem = 32*sizeof(float);
|
|
1194
1617
|
|
|
1195
1618
|
return res;
|
|
1196
1619
|
}
|
|
1197
1620
|
|
|
1198
|
-
|
|
1621
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1199
1622
|
assert(op->op == GGML_OP_ROPE);
|
|
1200
1623
|
|
|
1201
1624
|
char base[256];
|
|
@@ -1205,11 +1628,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
|
|
|
1205
1628
|
|
|
1206
1629
|
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
|
1207
1630
|
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
|
1631
|
+
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
|
|
1208
1632
|
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
|
1209
1633
|
|
|
1210
1634
|
if (is_neox) {
|
|
1211
1635
|
snprintf(base, 256, "kernel_rope_neox_%s", ggml_type_name(op->src[0]->type));
|
|
1212
|
-
} else if (is_mrope && !is_vision) {
|
|
1636
|
+
} else if ((is_mrope || is_imrope) && !is_vision) {
|
|
1213
1637
|
GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
|
|
1214
1638
|
snprintf(base, 256, "kernel_rope_multi_%s", ggml_type_name(op->src[0]->type));
|
|
1215
1639
|
} else if (is_vision) {
|
|
@@ -1219,19 +1643,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
|
|
|
1219
1643
|
snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type));
|
|
1220
1644
|
}
|
|
1221
1645
|
|
|
1222
|
-
snprintf(name, 256, "%
|
|
1646
|
+
snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
|
|
1223
1647
|
|
|
1224
|
-
|
|
1225
|
-
if (res) {
|
|
1226
|
-
|
|
1227
|
-
}
|
|
1648
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1649
|
+
if (!res.pipeline) {
|
|
1650
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
1228
1651
|
|
|
1229
|
-
|
|
1652
|
+
ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
|
|
1653
|
+
|
|
1654
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1655
|
+
|
|
1656
|
+
ggml_metal_cv_free(cv);
|
|
1657
|
+
}
|
|
1230
1658
|
|
|
1231
1659
|
return res;
|
|
1232
1660
|
}
|
|
1233
1661
|
|
|
1234
|
-
|
|
1662
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1235
1663
|
assert(op->op == GGML_OP_IM2COL);
|
|
1236
1664
|
|
|
1237
1665
|
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
|
|
@@ -1244,17 +1672,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_
|
|
|
1244
1672
|
snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type));
|
|
1245
1673
|
snprintf(name, 256, "%s", base);
|
|
1246
1674
|
|
|
1247
|
-
|
|
1248
|
-
if (res) {
|
|
1249
|
-
|
|
1675
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1676
|
+
if (!res.pipeline) {
|
|
1677
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1250
1678
|
}
|
|
1251
1679
|
|
|
1252
|
-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1253
|
-
|
|
1254
1680
|
return res;
|
|
1255
1681
|
}
|
|
1256
1682
|
|
|
1257
|
-
|
|
1683
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1258
1684
|
assert(op->op == GGML_OP_CONV_TRANSPOSE_1D);
|
|
1259
1685
|
|
|
1260
1686
|
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
@@ -1269,36 +1695,94 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_met
|
|
|
1269
1695
|
snprintf(base, 256, "kernel_conv_transpose_1d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
|
|
1270
1696
|
snprintf(name, 256, "%s", base);
|
|
1271
1697
|
|
|
1272
|
-
|
|
1273
|
-
if (res) {
|
|
1274
|
-
|
|
1698
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1699
|
+
if (!res.pipeline) {
|
|
1700
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1275
1701
|
}
|
|
1276
1702
|
|
|
1277
|
-
res
|
|
1703
|
+
return res;
|
|
1704
|
+
}
|
|
1705
|
+
|
|
1706
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1707
|
+
assert(op->op == GGML_OP_CONV_TRANSPOSE_2D);
|
|
1708
|
+
|
|
1709
|
+
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
1710
|
+
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
|
|
1711
|
+
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
|
|
1712
|
+
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
|
1713
|
+
GGML_ASSERT(op->type == GGML_TYPE_F32);
|
|
1714
|
+
|
|
1715
|
+
char base[256];
|
|
1716
|
+
char name[256];
|
|
1717
|
+
|
|
1718
|
+
snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
|
|
1719
|
+
snprintf(name, 256, "%s", base);
|
|
1720
|
+
|
|
1721
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1722
|
+
if (!res.pipeline) {
|
|
1723
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1724
|
+
}
|
|
1278
1725
|
|
|
1279
1726
|
return res;
|
|
1280
1727
|
}
|
|
1281
1728
|
|
|
1282
|
-
|
|
1283
|
-
assert(op->op ==
|
|
1729
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1730
|
+
assert(op->op == GGML_OP_CONV_2D);
|
|
1731
|
+
|
|
1732
|
+
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
1733
|
+
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
|
|
1734
|
+
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
|
1735
|
+
GGML_ASSERT(op->type == GGML_TYPE_F32);
|
|
1284
1736
|
|
|
1285
1737
|
char base[256];
|
|
1286
1738
|
char name[256];
|
|
1287
1739
|
|
|
1288
|
-
snprintf(base, 256, "
|
|
1740
|
+
snprintf(base, 256, "kernel_conv_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
|
|
1289
1741
|
snprintf(name, 256, "%s", base);
|
|
1290
1742
|
|
|
1291
|
-
|
|
1292
|
-
if (res) {
|
|
1293
|
-
|
|
1743
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1744
|
+
if (!res.pipeline) {
|
|
1745
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1294
1746
|
}
|
|
1295
1747
|
|
|
1296
|
-
res
|
|
1748
|
+
return res;
|
|
1749
|
+
}
|
|
1750
|
+
|
|
1751
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1752
|
+
assert(op->op == GGML_OP_UPSCALE);
|
|
1753
|
+
|
|
1754
|
+
char base[256];
|
|
1755
|
+
char name[256];
|
|
1756
|
+
|
|
1757
|
+
const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
|
|
1758
|
+
const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
|
|
1759
|
+
|
|
1760
|
+
const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS);
|
|
1761
|
+
|
|
1762
|
+
if (mode == GGML_SCALE_MODE_BILINEAR) {
|
|
1763
|
+
snprintf(base, 256, "kernel_upscale_bilinear_%s", ggml_type_name(op->src[0]->type));
|
|
1764
|
+
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
|
|
1765
|
+
snprintf(base, 256, "kernel_upscale_bicubic_%s", ggml_type_name(op->src[0]->type));
|
|
1766
|
+
} else {
|
|
1767
|
+
snprintf(base, 256, "kernel_upscale_nearest_%s", ggml_type_name(op->src[0]->type));
|
|
1768
|
+
}
|
|
1769
|
+
snprintf(name, 256, "%s_aa=%d", base, antialias);
|
|
1770
|
+
|
|
1771
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1772
|
+
if (!res.pipeline) {
|
|
1773
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
1774
|
+
|
|
1775
|
+
ggml_metal_cv_set_bool(cv, antialias, FC_UPSCALE + 0);
|
|
1776
|
+
|
|
1777
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1778
|
+
|
|
1779
|
+
ggml_metal_cv_free(cv);
|
|
1780
|
+
}
|
|
1297
1781
|
|
|
1298
1782
|
return res;
|
|
1299
1783
|
}
|
|
1300
1784
|
|
|
1301
|
-
|
|
1785
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1302
1786
|
assert(op->op == GGML_OP_PAD);
|
|
1303
1787
|
|
|
1304
1788
|
char base[256];
|
|
@@ -1307,8 +1791,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t l
|
|
|
1307
1791
|
snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type));
|
|
1308
1792
|
snprintf(name, 256, "%s", base);
|
|
1309
1793
|
|
|
1310
|
-
|
|
1311
|
-
if (res) {
|
|
1794
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1795
|
+
if (res.pipeline) {
|
|
1312
1796
|
return res;
|
|
1313
1797
|
}
|
|
1314
1798
|
|
|
@@ -1317,7 +1801,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t l
|
|
|
1317
1801
|
return res;
|
|
1318
1802
|
}
|
|
1319
1803
|
|
|
1320
|
-
|
|
1804
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1321
1805
|
assert(op->op == GGML_OP_PAD_REFLECT_1D);
|
|
1322
1806
|
|
|
1323
1807
|
char base[256];
|
|
@@ -1326,17 +1810,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_
|
|
|
1326
1810
|
snprintf(base, 256, "kernel_pad_reflect_1d_%s", ggml_type_name(op->src[0]->type));
|
|
1327
1811
|
snprintf(name, 256, "%s", base);
|
|
1328
1812
|
|
|
1329
|
-
|
|
1330
|
-
if (res) {
|
|
1331
|
-
|
|
1813
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1814
|
+
if (!res.pipeline) {
|
|
1815
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1332
1816
|
}
|
|
1333
1817
|
|
|
1334
|
-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1335
|
-
|
|
1336
1818
|
return res;
|
|
1337
1819
|
}
|
|
1338
1820
|
|
|
1339
|
-
|
|
1821
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1340
1822
|
assert(op->op == GGML_OP_ARANGE);
|
|
1341
1823
|
|
|
1342
1824
|
char base[256];
|
|
@@ -1345,17 +1827,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange(ggml_metal_library_
|
|
|
1345
1827
|
snprintf(base, 256, "kernel_arange_%s", ggml_type_name(op->type));
|
|
1346
1828
|
snprintf(name, 256, "%s", base);
|
|
1347
1829
|
|
|
1348
|
-
|
|
1349
|
-
if (res) {
|
|
1350
|
-
|
|
1830
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1831
|
+
if (!res.pipeline) {
|
|
1832
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1351
1833
|
}
|
|
1352
1834
|
|
|
1353
|
-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1354
|
-
|
|
1355
1835
|
return res;
|
|
1356
1836
|
}
|
|
1357
1837
|
|
|
1358
|
-
|
|
1838
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1359
1839
|
assert(op->op == GGML_OP_TIMESTEP_EMBEDDING);
|
|
1360
1840
|
|
|
1361
1841
|
char base[256];
|
|
@@ -1364,13 +1844,101 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me
|
|
|
1364
1844
|
snprintf(base, 256, "kernel_timestep_embedding_%s", ggml_type_name(op->src[0]->type));
|
|
1365
1845
|
snprintf(name, 256, "%s", base);
|
|
1366
1846
|
|
|
1367
|
-
|
|
1368
|
-
if (res) {
|
|
1369
|
-
|
|
1847
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1848
|
+
if (!res.pipeline) {
|
|
1849
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1370
1850
|
}
|
|
1371
1851
|
|
|
1372
|
-
res
|
|
1852
|
+
return res;
|
|
1853
|
+
}
|
|
1854
|
+
|
|
1855
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1856
|
+
assert(op->op == GGML_OP_OPT_STEP_ADAMW);
|
|
1857
|
+
|
|
1858
|
+
char base[256];
|
|
1859
|
+
char name[256];
|
|
1860
|
+
|
|
1861
|
+
snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
|
|
1862
|
+
snprintf(name, 256, "%s", base);
|
|
1863
|
+
|
|
1864
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1865
|
+
if (!res.pipeline) {
|
|
1866
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1867
|
+
}
|
|
1868
|
+
|
|
1869
|
+
return res;
|
|
1870
|
+
}
|
|
1871
|
+
|
|
1872
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1873
|
+
assert(op->op == GGML_OP_OPT_STEP_SGD);
|
|
1874
|
+
|
|
1875
|
+
char base[256];
|
|
1876
|
+
char name[256];
|
|
1877
|
+
|
|
1878
|
+
snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
|
|
1879
|
+
snprintf(name, 256, "%s", base);
|
|
1880
|
+
|
|
1881
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1882
|
+
if (!res.pipeline) {
|
|
1883
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1884
|
+
}
|
|
1373
1885
|
|
|
1374
1886
|
return res;
|
|
1375
1887
|
}
|
|
1376
1888
|
|
|
1889
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1890
|
+
GGML_ASSERT(op->type == GGML_TYPE_I64);
|
|
1891
|
+
|
|
1892
|
+
char base[256];
|
|
1893
|
+
char name[256];
|
|
1894
|
+
|
|
1895
|
+
snprintf(base, 256, "kernel_memset_%s", ggml_type_name(op->type));
|
|
1896
|
+
snprintf(name, 256, "%s", base);
|
|
1897
|
+
|
|
1898
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1899
|
+
if (!res.pipeline) {
|
|
1900
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1901
|
+
}
|
|
1902
|
+
|
|
1903
|
+
return res;
|
|
1904
|
+
}
|
|
1905
|
+
|
|
1906
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1907
|
+
assert(op->op == GGML_OP_COUNT_EQUAL);
|
|
1908
|
+
|
|
1909
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne);
|
|
1910
|
+
|
|
1911
|
+
GGML_ASSERT(op->src[0]->type == op->src[1]->type);
|
|
1912
|
+
GGML_ASSERT(op->src[0]->type == GGML_TYPE_I32);
|
|
1913
|
+
GGML_ASSERT(op->type == GGML_TYPE_I64);
|
|
1914
|
+
|
|
1915
|
+
// note: the kernel only supports i32 output due to metal atomic add only supporting atomic_int
|
|
1916
|
+
GGML_ASSERT(ggml_nelements(op->src[0]) < (1LL << 31));
|
|
1917
|
+
|
|
1918
|
+
char base[256];
|
|
1919
|
+
char name[256];
|
|
1920
|
+
|
|
1921
|
+
int nsg = 1;
|
|
1922
|
+
while (32*nsg < ne00 && nsg < 32) {
|
|
1923
|
+
nsg *= 2;
|
|
1924
|
+
}
|
|
1925
|
+
|
|
1926
|
+
snprintf(base, 256, "kernel_count_equal_%s", ggml_type_name(op->src[0]->type));
|
|
1927
|
+
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
|
1928
|
+
|
|
1929
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1930
|
+
if (!res.pipeline) {
|
|
1931
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
1932
|
+
|
|
1933
|
+
ggml_metal_cv_set_int16(cv, nsg, FC_COUNT_EQUAL + 0);
|
|
1934
|
+
|
|
1935
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1936
|
+
|
|
1937
|
+
ggml_metal_cv_free(cv);
|
|
1938
|
+
}
|
|
1939
|
+
|
|
1940
|
+
res.smem = 32 * sizeof(int32_t);
|
|
1941
|
+
res.nsg = nsg;
|
|
1942
|
+
|
|
1943
|
+
return res;
|
|
1944
|
+
}
|