whispercpp 1.3.4 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +158 -44
- data/ext/extconf.rb +3 -2
- data/ext/ruby_whisper.c +34 -6
- data/ext/ruby_whisper.h +67 -0
- data/ext/ruby_whisper_context.c +236 -144
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +12 -13
- data/ext/ruby_whisper_params.c +47 -24
- data/ext/ruby_whisper_segment.c +84 -20
- data/ext/ruby_whisper_token.c +371 -0
- data/ext/ruby_whisper_transcribe.cpp +5 -2
- data/ext/ruby_whisper_vad_context.c +122 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +138 -0
- data/ext/ruby_whisper_vad_segments.c +105 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +129 -112
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +28 -15
- data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
- data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
- data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
- data/ext/sources/examples/talk-llama/llama-context.h +70 -23
- data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
- data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
- data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
- data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
- data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
- data/ext/sources/examples/talk-llama/llama-model.h +112 -18
- data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
- data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
- data/ext/sources/examples/talk-llama/llama.cpp +802 -21
- data/ext/sources/examples/talk-llama/llama.h +210 -39
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +704 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +90 -56
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +5 -2
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +6 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +14 -12
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +246 -21
- data/ext/sources/ggml/src/CMakeLists.txt +85 -11
- data/ext/sources/ggml/src/ggml-alloc.c +128 -50
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
- data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
- data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
- data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
- data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
- data/ext/sources/ggml/src/ggml-impl.h +129 -6
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
- data/ext/sources/ggml/src/ggml.c +590 -64
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +106 -62
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +162 -4
- data/test/test_context_params.rb +82 -0
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +81 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +100 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +27 -0
- data/whispercpp.gemspec +1 -1
- metadata +502 -37
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -10,6 +10,8 @@
|
|
|
10
10
|
|
|
11
11
|
#include <cassert>
|
|
12
12
|
#include <algorithm>
|
|
13
|
+
#include <limits>
|
|
14
|
+
#include <cmath>
|
|
13
15
|
|
|
14
16
|
static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) {
|
|
15
17
|
if (!t) {
|
|
@@ -201,6 +203,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
201
203
|
GGML_ABORT("unsupported op");
|
|
202
204
|
}
|
|
203
205
|
|
|
206
|
+
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
|
207
|
+
return 1;
|
|
208
|
+
}
|
|
209
|
+
|
|
204
210
|
int n_fuse = 1;
|
|
205
211
|
|
|
206
212
|
// check if the current node can run concurrently with other nodes before it
|
|
@@ -219,13 +225,17 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
219
225
|
}
|
|
220
226
|
|
|
221
227
|
if (ctx->debug_graph > 0) {
|
|
222
|
-
GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), is_concurrent ? "(concurrent)" : "");
|
|
228
|
+
GGML_LOG_DEBUG("%s: node[%5d] - %-12s %-12s %s\n", __func__, idx, ggml_op_name(node->op), ggml_get_name(node), is_concurrent ? "(concurrent)" : "");
|
|
223
229
|
}
|
|
224
230
|
if (ctx->debug_graph > 1) {
|
|
225
231
|
GGML_TENSOR_LOCALS( int64_t, ne0, node->src[0], ne);
|
|
226
232
|
GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
|
|
227
233
|
GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
|
|
228
234
|
GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
|
|
235
|
+
GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
|
|
236
|
+
GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
|
|
237
|
+
GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
|
|
238
|
+
GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
|
|
229
239
|
GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
|
|
230
240
|
GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);
|
|
231
241
|
|
|
@@ -237,6 +247,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
237
247
|
GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
|
|
238
248
|
ggml_is_contiguous(node->src[1]), node->src[1]->name);
|
|
239
249
|
}
|
|
250
|
+
if (node->src[2]) {
|
|
251
|
+
GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
|
|
252
|
+
ggml_is_contiguous(node->src[2]), node->src[2]->name);
|
|
253
|
+
}
|
|
254
|
+
if (node->src[3]) {
|
|
255
|
+
GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
|
|
256
|
+
ggml_is_contiguous(node->src[3]), node->src[3]->name);
|
|
257
|
+
}
|
|
240
258
|
if (node) {
|
|
241
259
|
GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
|
|
242
260
|
node->name);
|
|
@@ -269,13 +287,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
269
287
|
n_fuse = ggml_metal_op_acc(ctx, idx);
|
|
270
288
|
} break;
|
|
271
289
|
case GGML_OP_SCALE:
|
|
272
|
-
|
|
273
|
-
n_fuse = ggml_metal_op_scale(ctx, idx);
|
|
274
|
-
} break;
|
|
290
|
+
case GGML_OP_FILL:
|
|
275
291
|
case GGML_OP_CLAMP:
|
|
276
|
-
|
|
277
|
-
n_fuse = ggml_metal_op_clamp(ctx, idx);
|
|
278
|
-
} break;
|
|
292
|
+
case GGML_OP_LEAKY_RELU:
|
|
279
293
|
case GGML_OP_SQR:
|
|
280
294
|
case GGML_OP_SQRT:
|
|
281
295
|
case GGML_OP_SIN:
|
|
@@ -289,11 +303,19 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
289
303
|
{
|
|
290
304
|
n_fuse = ggml_metal_op_glu(ctx, idx);
|
|
291
305
|
} break;
|
|
306
|
+
case GGML_OP_SUM:
|
|
307
|
+
{
|
|
308
|
+
n_fuse = ggml_metal_op_sum(ctx, idx);
|
|
309
|
+
} break;
|
|
292
310
|
case GGML_OP_SUM_ROWS:
|
|
293
311
|
case GGML_OP_MEAN:
|
|
294
312
|
{
|
|
295
313
|
n_fuse = ggml_metal_op_sum_rows(ctx, idx);
|
|
296
314
|
} break;
|
|
315
|
+
case GGML_OP_CUMSUM:
|
|
316
|
+
{
|
|
317
|
+
n_fuse = ggml_metal_op_cumsum(ctx, idx);
|
|
318
|
+
} break;
|
|
297
319
|
case GGML_OP_SOFT_MAX:
|
|
298
320
|
{
|
|
299
321
|
n_fuse = ggml_metal_op_soft_max(ctx, idx);
|
|
@@ -311,6 +333,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
311
333
|
{
|
|
312
334
|
n_fuse = ggml_metal_op_rwkv(ctx, idx);
|
|
313
335
|
} break;
|
|
336
|
+
case GGML_OP_GATED_DELTA_NET:
|
|
337
|
+
{
|
|
338
|
+
n_fuse = ggml_metal_op_gated_delta_net(ctx, idx);
|
|
339
|
+
} break;
|
|
340
|
+
case GGML_OP_SOLVE_TRI:
|
|
341
|
+
{
|
|
342
|
+
n_fuse = ggml_metal_op_solve_tri(ctx, idx);
|
|
343
|
+
} break;
|
|
314
344
|
case GGML_OP_MUL_MAT:
|
|
315
345
|
{
|
|
316
346
|
n_fuse = ggml_metal_op_mul_mat(ctx, idx);
|
|
@@ -327,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
327
357
|
{
|
|
328
358
|
n_fuse = ggml_metal_op_set_rows(ctx, idx);
|
|
329
359
|
} break;
|
|
360
|
+
case GGML_OP_DIAG:
|
|
361
|
+
{
|
|
362
|
+
n_fuse = ggml_metal_op_diag(ctx, idx);
|
|
363
|
+
} break;
|
|
330
364
|
case GGML_OP_L2_NORM:
|
|
331
365
|
{
|
|
332
366
|
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
|
|
@@ -348,10 +382,18 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
348
382
|
{
|
|
349
383
|
n_fuse = ggml_metal_op_im2col(ctx, idx);
|
|
350
384
|
} break;
|
|
385
|
+
case GGML_OP_CONV_2D:
|
|
386
|
+
{
|
|
387
|
+
n_fuse = ggml_metal_op_conv_2d(ctx, idx);
|
|
388
|
+
} break;
|
|
351
389
|
case GGML_OP_CONV_TRANSPOSE_1D:
|
|
352
390
|
{
|
|
353
391
|
n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
|
|
354
392
|
} break;
|
|
393
|
+
case GGML_OP_CONV_TRANSPOSE_2D:
|
|
394
|
+
{
|
|
395
|
+
n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
|
|
396
|
+
} break;
|
|
355
397
|
case GGML_OP_UPSCALE:
|
|
356
398
|
{
|
|
357
399
|
n_fuse = ggml_metal_op_upscale(ctx, idx);
|
|
@@ -376,20 +418,32 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
376
418
|
{
|
|
377
419
|
n_fuse = ggml_metal_op_argsort(ctx, idx);
|
|
378
420
|
} break;
|
|
379
|
-
case
|
|
421
|
+
case GGML_OP_TOP_K:
|
|
422
|
+
{
|
|
423
|
+
n_fuse = ggml_metal_op_top_k(ctx, idx);
|
|
424
|
+
} break;
|
|
425
|
+
case GGML_OP_TRI:
|
|
380
426
|
{
|
|
381
|
-
n_fuse =
|
|
427
|
+
n_fuse = ggml_metal_op_tri(ctx, idx);
|
|
382
428
|
} break;
|
|
383
429
|
case GGML_OP_FLASH_ATTN_EXT:
|
|
384
430
|
{
|
|
385
431
|
n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
|
|
386
432
|
} break;
|
|
433
|
+
case GGML_OP_SET:
|
|
434
|
+
{
|
|
435
|
+
n_fuse = ggml_metal_op_set(ctx, idx);
|
|
436
|
+
} break;
|
|
387
437
|
case GGML_OP_DUP:
|
|
388
438
|
case GGML_OP_CPY:
|
|
389
439
|
case GGML_OP_CONT:
|
|
390
440
|
{
|
|
391
441
|
n_fuse = ggml_metal_op_cpy(ctx, idx);
|
|
392
442
|
} break;
|
|
443
|
+
case GGML_OP_POOL_1D:
|
|
444
|
+
{
|
|
445
|
+
n_fuse = ggml_metal_op_pool_1d(ctx, idx);
|
|
446
|
+
} break;
|
|
393
447
|
case GGML_OP_POOL_2D:
|
|
394
448
|
{
|
|
395
449
|
n_fuse = ggml_metal_op_pool_2d(ctx, idx);
|
|
@@ -398,7 +452,19 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
398
452
|
{
|
|
399
453
|
n_fuse = ggml_metal_op_argmax(ctx, idx);
|
|
400
454
|
} break;
|
|
401
|
-
|
|
455
|
+
case GGML_OP_OPT_STEP_ADAMW:
|
|
456
|
+
{
|
|
457
|
+
n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
|
|
458
|
+
} break;
|
|
459
|
+
case GGML_OP_OPT_STEP_SGD:
|
|
460
|
+
{
|
|
461
|
+
n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
|
|
462
|
+
} break;
|
|
463
|
+
case GGML_OP_COUNT_EQUAL:
|
|
464
|
+
{
|
|
465
|
+
n_fuse = ggml_metal_op_count_equal(ctx, idx);
|
|
466
|
+
} break;
|
|
467
|
+
default:
|
|
402
468
|
{
|
|
403
469
|
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
|
|
404
470
|
GGML_ABORT("fatal error");
|
|
@@ -482,7 +548,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
|
|
|
482
548
|
/*.dim =*/ dim,
|
|
483
549
|
};
|
|
484
550
|
|
|
485
|
-
|
|
551
|
+
auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
|
|
486
552
|
|
|
487
553
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
488
554
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -506,9 +572,9 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
|
|
|
506
572
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
507
573
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
508
574
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
509
|
-
GGML_TENSOR_LOCALS(
|
|
575
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
510
576
|
|
|
511
|
-
|
|
577
|
+
auto pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
|
|
512
578
|
|
|
513
579
|
ggml_metal_kargs_repeat args = {
|
|
514
580
|
/*.ne00 =*/ ne00,
|
|
@@ -552,14 +618,14 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|
|
552
618
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
553
619
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
554
620
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
555
|
-
GGML_TENSOR_LOCALS(
|
|
621
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
556
622
|
|
|
557
623
|
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
|
|
558
624
|
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
|
559
625
|
GGML_ASSERT(op->type == GGML_TYPE_F32);
|
|
560
626
|
|
|
561
|
-
GGML_ASSERT(
|
|
562
|
-
GGML_ASSERT(
|
|
627
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
628
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
|
|
563
629
|
|
|
564
630
|
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
|
|
565
631
|
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
|
|
@@ -569,14 +635,15 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|
|
569
635
|
const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
|
|
570
636
|
|
|
571
637
|
if (!inplace) {
|
|
572
|
-
// run a
|
|
638
|
+
// run a separate kernel to cpy src->dst
|
|
573
639
|
// not sure how to avoid this
|
|
574
640
|
// TODO: make a simpler cpy_bytes kernel
|
|
575
641
|
|
|
576
642
|
//const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
|
|
577
|
-
|
|
643
|
+
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
|
578
644
|
|
|
579
645
|
ggml_metal_kargs_cpy args = {
|
|
646
|
+
/*.nk0 =*/ ne00,
|
|
580
647
|
/*.ne00 =*/ ne00,
|
|
581
648
|
/*.ne01 =*/ ne01,
|
|
582
649
|
/*.ne02 =*/ ne02,
|
|
@@ -608,10 +675,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|
|
608
675
|
}
|
|
609
676
|
|
|
610
677
|
ggml_metal_kargs_bin args = {
|
|
611
|
-
/*.ne00 =*/
|
|
612
|
-
/*.ne01 =*/
|
|
613
|
-
/*.ne02 =*/
|
|
614
|
-
/*.ne03 =*/
|
|
678
|
+
/*.ne00 =*/ ne10,
|
|
679
|
+
/*.ne01 =*/ ne11,
|
|
680
|
+
/*.ne02 =*/ ne12,
|
|
681
|
+
/*.ne03 =*/ ne13,
|
|
615
682
|
/*.nb00 =*/ nb00,
|
|
616
683
|
/*.nb01 =*/ pnb1,
|
|
617
684
|
/*.nb02 =*/ pnb2,
|
|
@@ -624,10 +691,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|
|
624
691
|
/*.nb11 =*/ nb11,
|
|
625
692
|
/*.nb12 =*/ nb12,
|
|
626
693
|
/*.nb13 =*/ nb13,
|
|
627
|
-
/*.ne0 =*/
|
|
628
|
-
/*.ne1 =*/
|
|
629
|
-
/*.ne2 =*/
|
|
630
|
-
/*.ne3 =*/
|
|
694
|
+
/*.ne0 =*/ ne10,
|
|
695
|
+
/*.ne1 =*/ ne11,
|
|
696
|
+
/*.ne2 =*/ ne12,
|
|
697
|
+
/*.ne3 =*/ ne13,
|
|
631
698
|
/*.nb0 =*/ nb0,
|
|
632
699
|
/*.nb1 =*/ pnb1,
|
|
633
700
|
/*.nb2 =*/ pnb2,
|
|
@@ -636,7 +703,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|
|
636
703
|
/*.o1 =*/ { 0 },
|
|
637
704
|
};
|
|
638
705
|
|
|
639
|
-
|
|
706
|
+
auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD);
|
|
640
707
|
|
|
641
708
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
642
709
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -644,14 +711,20 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|
|
644
711
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
645
712
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
|
646
713
|
|
|
647
|
-
const int
|
|
714
|
+
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
715
|
+
|
|
716
|
+
int nth = 1;
|
|
717
|
+
|
|
718
|
+
while (2*nth < args.ne0 && nth < nth_max) {
|
|
719
|
+
nth *= 2;
|
|
720
|
+
}
|
|
648
721
|
|
|
649
722
|
ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
|
|
650
723
|
|
|
651
724
|
return 1;
|
|
652
725
|
}
|
|
653
726
|
|
|
654
|
-
int
|
|
727
|
+
int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
|
655
728
|
ggml_tensor * op = ctx->node(idx);
|
|
656
729
|
|
|
657
730
|
ggml_metal_library_t lib = ctx->lib;
|
|
@@ -660,100 +733,82 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
|
|
|
660
733
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
661
734
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
662
735
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
663
|
-
GGML_TENSOR_LOCALS(
|
|
736
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
664
737
|
|
|
665
|
-
|
|
666
|
-
float bias;
|
|
667
|
-
memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float));
|
|
668
|
-
memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float));
|
|
738
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
669
739
|
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
/*.bias =*/ bias,
|
|
673
|
-
};
|
|
740
|
+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
741
|
+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
674
742
|
|
|
675
|
-
|
|
743
|
+
ggml_metal_kargs_unary args = {
|
|
744
|
+
/*.ne00 =*/ ne00,
|
|
745
|
+
/*.ne01 =*/ ne01,
|
|
746
|
+
/*.ne02 =*/ ne02,
|
|
747
|
+
/*.ne03 =*/ ne03,
|
|
748
|
+
/*.nb00 =*/ nb00,
|
|
749
|
+
/*.nb01 =*/ nb01,
|
|
750
|
+
/*.nb02 =*/ nb02,
|
|
751
|
+
/*.nb03 =*/ nb03,
|
|
752
|
+
/*.ne0 =*/ ne0,
|
|
753
|
+
/*.ne1 =*/ ne1,
|
|
754
|
+
/*.ne2 =*/ ne2,
|
|
755
|
+
/*.ne3 =*/ ne3,
|
|
756
|
+
/*.nb0 =*/ nb0,
|
|
757
|
+
/*.nb1 =*/ nb1,
|
|
758
|
+
/*.nb2 =*/ nb2,
|
|
759
|
+
/*.nb3 =*/ nb3,
|
|
760
|
+
/*.slope =*/ 0.0,
|
|
761
|
+
/*.scale =*/ 0.0,
|
|
762
|
+
/*.bias =*/ 0.0,
|
|
763
|
+
/*.val =*/ 0.0,
|
|
764
|
+
/*.min =*/ 0.0,
|
|
765
|
+
/*.max =*/ 0.0,
|
|
766
|
+
};
|
|
676
767
|
|
|
677
|
-
if (
|
|
678
|
-
|
|
768
|
+
if (op->op == GGML_OP_LEAKY_RELU) {
|
|
769
|
+
args.slope = ggml_get_op_params_f32(op, 0);
|
|
679
770
|
}
|
|
680
771
|
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
686
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
687
|
-
|
|
688
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
|
689
|
-
|
|
690
|
-
return 1;
|
|
691
|
-
}
|
|
692
|
-
|
|
693
|
-
int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
|
|
694
|
-
ggml_tensor * op = ctx->node(idx);
|
|
695
|
-
|
|
696
|
-
ggml_metal_library_t lib = ctx->lib;
|
|
697
|
-
ggml_metal_encoder_t enc = ctx->enc;
|
|
698
|
-
|
|
699
|
-
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
700
|
-
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
701
|
-
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
702
|
-
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
|
772
|
+
if (op->op == GGML_OP_SCALE) {
|
|
773
|
+
args.scale = ggml_get_op_params_f32(op, 0);
|
|
774
|
+
args.bias = ggml_get_op_params_f32(op, 1);
|
|
775
|
+
}
|
|
703
776
|
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float));
|
|
777
|
+
if (op->op == GGML_OP_FILL) {
|
|
778
|
+
args.val = ggml_get_op_params_f32(op, 0);
|
|
779
|
+
}
|
|
708
780
|
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
}
|
|
781
|
+
if (op->op == GGML_OP_CLAMP) {
|
|
782
|
+
args.min = ggml_get_op_params_f32(op, 0);
|
|
783
|
+
args.max = ggml_get_op_params_f32(op, 1);
|
|
784
|
+
}
|
|
713
785
|
|
|
714
|
-
|
|
786
|
+
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
|
715
787
|
|
|
716
|
-
if (
|
|
717
|
-
|
|
788
|
+
if (pipeline.c4) {
|
|
789
|
+
args.ne00 = ne00/4;
|
|
790
|
+
args.ne0 = ne0/4;
|
|
718
791
|
}
|
|
719
792
|
|
|
720
|
-
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
|
721
|
-
|
|
722
793
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
723
794
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
724
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
725
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
726
|
-
|
|
727
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
|
728
|
-
|
|
729
|
-
return 1;
|
|
730
|
-
}
|
|
795
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
796
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
731
797
|
|
|
732
|
-
|
|
733
|
-
|
|
798
|
+
if (pipeline.cnt) {
|
|
799
|
+
const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
|
|
734
800
|
|
|
735
|
-
|
|
736
|
-
|
|
801
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
|
802
|
+
} else {
|
|
803
|
+
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
737
804
|
|
|
738
|
-
|
|
739
|
-
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
740
|
-
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
741
|
-
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
|
805
|
+
const int nth = MIN(args.ne00, nth_max);
|
|
742
806
|
|
|
743
|
-
|
|
807
|
+
const int nk0 = (args.ne00 + nth - 1)/nth;
|
|
744
808
|
|
|
745
|
-
|
|
746
|
-
n /= 4;
|
|
809
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
|
|
747
810
|
}
|
|
748
811
|
|
|
749
|
-
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
|
750
|
-
|
|
751
|
-
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
752
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
|
|
753
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
|
|
754
|
-
|
|
755
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
|
756
|
-
|
|
757
812
|
return 1;
|
|
758
813
|
}
|
|
759
814
|
|
|
@@ -768,13 +823,13 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
|
|
768
823
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
769
824
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
770
825
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
771
|
-
GGML_TENSOR_LOCALS(
|
|
826
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
772
827
|
|
|
773
828
|
if (op->src[1]) {
|
|
774
829
|
GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
|
|
775
830
|
}
|
|
776
831
|
|
|
777
|
-
|
|
832
|
+
auto pipeline = ggml_metal_library_get_pipeline_glu(lib, op);
|
|
778
833
|
|
|
779
834
|
const int32_t swp = ggml_get_op_params_i32(op, 1);
|
|
780
835
|
const float alpha = ggml_get_op_params_f32(op, 2);
|
|
@@ -800,18 +855,6 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
|
|
800
855
|
|
|
801
856
|
const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
|
|
802
857
|
|
|
803
|
-
//[encoder setComputePipelineState:pipeline];
|
|
804
|
-
//[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
805
|
-
//if (src1) {
|
|
806
|
-
// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
807
|
-
//} else {
|
|
808
|
-
// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
809
|
-
//}
|
|
810
|
-
//[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
811
|
-
//[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
|
812
|
-
|
|
813
|
-
//[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
814
|
-
|
|
815
858
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
816
859
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
817
860
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
@@ -827,6 +870,43 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
|
|
827
870
|
return 1;
|
|
828
871
|
}
|
|
829
872
|
|
|
873
|
+
int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
|
|
874
|
+
ggml_tensor * op = ctx->node(idx);
|
|
875
|
+
|
|
876
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
877
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
878
|
+
|
|
879
|
+
const uint64_t n = (uint64_t) ggml_nelements(op->src[0]);
|
|
880
|
+
|
|
881
|
+
ggml_metal_kargs_sum args = {
|
|
882
|
+
/*.np =*/ n,
|
|
883
|
+
};
|
|
884
|
+
|
|
885
|
+
auto pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
|
|
886
|
+
|
|
887
|
+
int nth = 32; // SIMD width
|
|
888
|
+
|
|
889
|
+
while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
890
|
+
nth *= 2;
|
|
891
|
+
}
|
|
892
|
+
|
|
893
|
+
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
894
|
+
nth = std::min(nth, (int) n);
|
|
895
|
+
|
|
896
|
+
const int nsg = (nth + 31) / 32;
|
|
897
|
+
|
|
898
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
899
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
900
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
901
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
902
|
+
|
|
903
|
+
ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
|
|
904
|
+
|
|
905
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
|
|
906
|
+
|
|
907
|
+
return 1;
|
|
908
|
+
}
|
|
909
|
+
|
|
830
910
|
int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
|
831
911
|
ggml_tensor * op = ctx->node(idx);
|
|
832
912
|
|
|
@@ -836,7 +916,12 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
|
|
836
916
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
837
917
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
838
918
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
839
|
-
GGML_TENSOR_LOCALS(
|
|
919
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
920
|
+
|
|
921
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
922
|
+
|
|
923
|
+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
924
|
+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
840
925
|
|
|
841
926
|
ggml_metal_kargs_sum_rows args = {
|
|
842
927
|
/*.ne00 =*/ ne00,
|
|
@@ -857,31 +942,28 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
|
|
857
942
|
/*.nb3 =*/ nb3,
|
|
858
943
|
};
|
|
859
944
|
|
|
860
|
-
|
|
945
|
+
auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
|
|
946
|
+
|
|
947
|
+
if (pipeline.c4) {
|
|
948
|
+
args.ne00 = ne00/4;
|
|
949
|
+
args.ne0 = ne0/4;
|
|
950
|
+
}
|
|
861
951
|
|
|
862
952
|
int nth = 32; // SIMD width
|
|
863
953
|
|
|
864
|
-
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
954
|
+
while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
865
955
|
nth *= 2;
|
|
866
956
|
}
|
|
867
957
|
|
|
868
958
|
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
869
|
-
nth = std::min(nth, ne00);
|
|
870
|
-
|
|
871
|
-
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
|
959
|
+
nth = std::min(nth, (int) args.ne00);
|
|
872
960
|
|
|
873
|
-
|
|
874
|
-
//[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
875
|
-
//[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
876
|
-
//[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
877
|
-
//[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
878
|
-
|
|
879
|
-
//[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
961
|
+
const size_t smem = pipeline.smem;
|
|
880
962
|
|
|
881
963
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
882
964
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
883
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
884
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
965
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
966
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
885
967
|
|
|
886
968
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
887
969
|
|
|
@@ -890,6 +972,149 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
|
|
890
972
|
return 1;
|
|
891
973
|
}
|
|
892
974
|
|
|
975
|
+
int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
|
|
976
|
+
ggml_tensor * op = ctx->node(idx);
|
|
977
|
+
|
|
978
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
979
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
980
|
+
|
|
981
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
982
|
+
|
|
983
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
984
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
985
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
986
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
987
|
+
|
|
988
|
+
auto pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
|
|
989
|
+
|
|
990
|
+
int nth = 1;
|
|
991
|
+
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
|
|
992
|
+
nth *= 2;
|
|
993
|
+
}
|
|
994
|
+
|
|
995
|
+
GGML_ASSERT(ne00 <= nth*nth);
|
|
996
|
+
|
|
997
|
+
const int64_t net0 = (ne00 + nth - 1) / nth;
|
|
998
|
+
const int64_t net1 = ne01;
|
|
999
|
+
const int64_t net2 = ne02;
|
|
1000
|
+
const int64_t net3 = ne03;
|
|
1001
|
+
|
|
1002
|
+
const uint64_t nbt0 = sizeof(float);
|
|
1003
|
+
const uint64_t nbt1 = net0*nbt0;
|
|
1004
|
+
const uint64_t nbt2 = net1*nbt1;
|
|
1005
|
+
const uint64_t nbt3 = net2*nbt2;
|
|
1006
|
+
|
|
1007
|
+
const size_t smem = GGML_PAD(32*sizeof(float), 16);
|
|
1008
|
+
|
|
1009
|
+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
1010
|
+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
1011
|
+
|
|
1012
|
+
ggml_metal_buffer_id bid_tmp = bid_dst;
|
|
1013
|
+
bid_tmp.offs += ggml_nbytes(op);
|
|
1014
|
+
|
|
1015
|
+
{
|
|
1016
|
+
ggml_metal_kargs_cumsum_blk args = {
|
|
1017
|
+
/*.ne00 =*/ ne00,
|
|
1018
|
+
/*.ne01 =*/ ne01,
|
|
1019
|
+
/*.ne02 =*/ ne02,
|
|
1020
|
+
/*.ne03 =*/ ne03,
|
|
1021
|
+
/*.nb00 =*/ nb00,
|
|
1022
|
+
/*.nb01 =*/ nb01,
|
|
1023
|
+
/*.nb02 =*/ nb02,
|
|
1024
|
+
/*.nb03 =*/ nb03,
|
|
1025
|
+
/*.net0 =*/ net0,
|
|
1026
|
+
/*.net1 =*/ net1,
|
|
1027
|
+
/*.net2 =*/ net2,
|
|
1028
|
+
/*.net3 =*/ net3,
|
|
1029
|
+
/*.nbt0 =*/ nbt0,
|
|
1030
|
+
/*.nbt1 =*/ nbt1,
|
|
1031
|
+
/*.nbt2 =*/ nbt2,
|
|
1032
|
+
/*.nbt3 =*/ nbt3,
|
|
1033
|
+
/*.outb =*/ ne00 > nth,
|
|
1034
|
+
};
|
|
1035
|
+
|
|
1036
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
|
|
1037
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
1038
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
1039
|
+
ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
|
|
1040
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
|
|
1041
|
+
|
|
1042
|
+
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
1043
|
+
|
|
1044
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
|
|
1045
|
+
}
|
|
1046
|
+
|
|
1047
|
+
if (ne00 > nth) {
|
|
1048
|
+
ggml_metal_op_concurrency_reset(ctx);
|
|
1049
|
+
|
|
1050
|
+
{
|
|
1051
|
+
ggml_metal_kargs_cumsum_blk args = {
|
|
1052
|
+
/*.ne00 =*/ net0,
|
|
1053
|
+
/*.ne01 =*/ net1,
|
|
1054
|
+
/*.ne02 =*/ net2,
|
|
1055
|
+
/*.ne03 =*/ net3,
|
|
1056
|
+
/*.nb00 =*/ nbt0,
|
|
1057
|
+
/*.nb01 =*/ nbt1,
|
|
1058
|
+
/*.nb02 =*/ nbt2,
|
|
1059
|
+
/*.nb03 =*/ nbt3,
|
|
1060
|
+
/*.net0 =*/ net0,
|
|
1061
|
+
/*.net1 =*/ net1,
|
|
1062
|
+
/*.net2 =*/ net2,
|
|
1063
|
+
/*.net3 =*/ net3,
|
|
1064
|
+
/*.nbt0 =*/ nbt0,
|
|
1065
|
+
/*.nbt1 =*/ nbt1,
|
|
1066
|
+
/*.nbt2 =*/ nbt2,
|
|
1067
|
+
/*.nbt3 =*/ nbt3,
|
|
1068
|
+
/*.outb =*/ false,
|
|
1069
|
+
};
|
|
1070
|
+
|
|
1071
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
|
|
1072
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
1073
|
+
ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
|
|
1074
|
+
ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
|
|
1075
|
+
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
|
|
1076
|
+
|
|
1077
|
+
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
1078
|
+
|
|
1079
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1);
|
|
1080
|
+
}
|
|
1081
|
+
|
|
1082
|
+
ggml_metal_op_concurrency_reset(ctx);
|
|
1083
|
+
|
|
1084
|
+
{
|
|
1085
|
+
auto pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
|
|
1086
|
+
|
|
1087
|
+
ggml_metal_kargs_cumsum_add args = {
|
|
1088
|
+
/*.ne00 =*/ ne00,
|
|
1089
|
+
/*.ne01 =*/ ne01,
|
|
1090
|
+
/*.ne02 =*/ ne02,
|
|
1091
|
+
/*.ne03 =*/ ne03,
|
|
1092
|
+
/*.nb00 =*/ nb00,
|
|
1093
|
+
/*.nb01 =*/ nb01,
|
|
1094
|
+
/*.nb02 =*/ nb02,
|
|
1095
|
+
/*.nb03 =*/ nb03,
|
|
1096
|
+
/*.net0 =*/ net0,
|
|
1097
|
+
/*.net1 =*/ net1,
|
|
1098
|
+
/*.net2 =*/ net2,
|
|
1099
|
+
/*.net3 =*/ net3,
|
|
1100
|
+
/*.nbt0 =*/ nbt0,
|
|
1101
|
+
/*.nbt1 =*/ nbt1,
|
|
1102
|
+
/*.nbt2 =*/ nbt2,
|
|
1103
|
+
/*.nbt3 =*/ nbt3,
|
|
1104
|
+
};
|
|
1105
|
+
|
|
1106
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline_add);
|
|
1107
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
1108
|
+
ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
|
|
1109
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
1110
|
+
|
|
1111
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
|
|
1112
|
+
}
|
|
1113
|
+
}
|
|
1114
|
+
|
|
1115
|
+
return 1;
|
|
1116
|
+
}
|
|
1117
|
+
|
|
893
1118
|
int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
|
|
894
1119
|
ggml_tensor * op = ctx->node(idx);
|
|
895
1120
|
|
|
@@ -901,28 +1126,36 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
|
|
|
901
1126
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
902
1127
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
903
1128
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
904
|
-
GGML_TENSOR_LOCALS(
|
|
1129
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
905
1130
|
|
|
906
|
-
|
|
1131
|
+
auto pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
|
|
907
1132
|
|
|
908
1133
|
ggml_metal_kargs_get_rows args = {
|
|
909
|
-
/*.
|
|
910
|
-
/*.
|
|
911
|
-
/*.
|
|
912
|
-
/*.
|
|
913
|
-
/*.
|
|
914
|
-
/*.
|
|
915
|
-
/*.
|
|
916
|
-
/*.
|
|
1134
|
+
/*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
|
|
1135
|
+
/*.ne00 =*/ ne00,
|
|
1136
|
+
/*.nb01 =*/ nb01,
|
|
1137
|
+
/*.nb02 =*/ nb02,
|
|
1138
|
+
/*.nb03 =*/ nb03,
|
|
1139
|
+
/*.ne10 =*/ ne10,
|
|
1140
|
+
/*.nb10 =*/ nb10,
|
|
1141
|
+
/*.nb11 =*/ nb11,
|
|
1142
|
+
/*.nb12 =*/ nb12,
|
|
1143
|
+
/*.nb1 =*/ nb1,
|
|
1144
|
+
/*.nb2 =*/ nb2,
|
|
1145
|
+
/*.nb3 =*/ nb3,
|
|
917
1146
|
};
|
|
918
1147
|
|
|
1148
|
+
const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
1149
|
+
|
|
1150
|
+
const int nw0 = (args.ne00t + nth - 1)/nth;
|
|
1151
|
+
|
|
919
1152
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
920
1153
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
921
1154
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
922
1155
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
923
1156
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
|
924
1157
|
|
|
925
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12,
|
|
1158
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
|
|
926
1159
|
|
|
927
1160
|
return 1;
|
|
928
1161
|
}
|
|
@@ -938,9 +1171,9 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
|
|
|
938
1171
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
939
1172
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
940
1173
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
941
|
-
GGML_TENSOR_LOCALS(
|
|
1174
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
942
1175
|
|
|
943
|
-
|
|
1176
|
+
auto pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
|
|
944
1177
|
|
|
945
1178
|
const int32_t nk0 = ne0/ggml_blck_size(op->type);
|
|
946
1179
|
|
|
@@ -989,6 +1222,48 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
|
|
|
989
1222
|
return 1;
|
|
990
1223
|
}
|
|
991
1224
|
|
|
1225
|
+
int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) {
|
|
1226
|
+
ggml_tensor * op = ctx->node(idx);
|
|
1227
|
+
|
|
1228
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
1229
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
1230
|
+
|
|
1231
|
+
GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
|
|
1232
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1233
|
+
GGML_TENSOR_LOCALS(int32_t, ne, op, ne);
|
|
1234
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1235
|
+
|
|
1236
|
+
ggml_metal_kargs_diag args = {
|
|
1237
|
+
/*.ne00 =*/ne00,
|
|
1238
|
+
/*.ne01 =*/ne01,
|
|
1239
|
+
/*.ne02 =*/ne02,
|
|
1240
|
+
/*.ne03 =*/ne03,
|
|
1241
|
+
/*.nb00 =*/nb00,
|
|
1242
|
+
/*.nb01 =*/nb01,
|
|
1243
|
+
/*.nb02 =*/nb02,
|
|
1244
|
+
/*.nb03 =*/nb03,
|
|
1245
|
+
/*.ne0 =*/ne0,
|
|
1246
|
+
/*.ne1 =*/ne1,
|
|
1247
|
+
/*.ne2 =*/ne2,
|
|
1248
|
+
/*.ne3 =*/ne3,
|
|
1249
|
+
/*.nb0 =*/nb0,
|
|
1250
|
+
/*.nb1 =*/nb1,
|
|
1251
|
+
/*.nb2 =*/nb2,
|
|
1252
|
+
/*.nb3 =*/nb3,
|
|
1253
|
+
};
|
|
1254
|
+
|
|
1255
|
+
auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op);
|
|
1256
|
+
|
|
1257
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1258
|
+
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
|
1259
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
1260
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 2);
|
|
1261
|
+
|
|
1262
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1);
|
|
1263
|
+
|
|
1264
|
+
return 1;
|
|
1265
|
+
}
|
|
1266
|
+
|
|
992
1267
|
int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
|
993
1268
|
ggml_tensor * op = ctx->node(idx);
|
|
994
1269
|
|
|
@@ -1002,7 +1277,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
|
|
1002
1277
|
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
|
1003
1278
|
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
|
1004
1279
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1005
|
-
GGML_TENSOR_LOCALS(
|
|
1280
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1006
1281
|
|
|
1007
1282
|
float scale;
|
|
1008
1283
|
float max_bias;
|
|
@@ -1041,7 +1316,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
|
|
1041
1316
|
/*.n_head_log2 =*/ n_head_log2,
|
|
1042
1317
|
};
|
|
1043
1318
|
|
|
1044
|
-
|
|
1319
|
+
auto pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op);
|
|
1045
1320
|
|
|
1046
1321
|
int nth = 32; // SIMD width
|
|
1047
1322
|
|
|
@@ -1055,7 +1330,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
|
|
1055
1330
|
}
|
|
1056
1331
|
}
|
|
1057
1332
|
|
|
1058
|
-
const size_t smem =
|
|
1333
|
+
const size_t smem = pipeline.smem;
|
|
1059
1334
|
|
|
1060
1335
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1061
1336
|
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
|
@@ -1090,7 +1365,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
|
|
|
1090
1365
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
1091
1366
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
1092
1367
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1093
|
-
GGML_TENSOR_LOCALS(
|
|
1368
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1094
1369
|
|
|
1095
1370
|
ggml_metal_kargs_ssm_conv args = {
|
|
1096
1371
|
/*.ne00 =*/ ne00,
|
|
@@ -1111,18 +1386,46 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
|
|
|
1111
1386
|
/*.nb2 =*/ nb2,
|
|
1112
1387
|
};
|
|
1113
1388
|
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1117
|
-
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
|
1118
|
-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
1119
|
-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
1120
|
-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
|
|
1389
|
+
// Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead
|
|
1390
|
+
const bool use_batched = (ne1 > 1);
|
|
1121
1391
|
|
|
1122
|
-
|
|
1392
|
+
if (use_batched) {
|
|
1393
|
+
// Determine the smallest power of 2 that's >= ne1, but <= 256
|
|
1394
|
+
int BATCH_SIZE;
|
|
1395
|
+
if (ne1 > 128) BATCH_SIZE = 256;
|
|
1396
|
+
else if (ne1 > 64 ) BATCH_SIZE = 128;
|
|
1397
|
+
else if (ne1 > 32 ) BATCH_SIZE = 64;
|
|
1398
|
+
else if (ne1 > 16 ) BATCH_SIZE = 32;
|
|
1399
|
+
else if (ne1 > 8 ) BATCH_SIZE = 16;
|
|
1400
|
+
else if (ne1 > 4 ) BATCH_SIZE = 8;
|
|
1401
|
+
else BATCH_SIZE = 2;
|
|
1123
1402
|
|
|
1124
|
-
|
|
1125
|
-
|
|
1403
|
+
auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE);
|
|
1404
|
+
|
|
1405
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1406
|
+
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
|
1407
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
1408
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
1409
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
|
|
1410
|
+
|
|
1411
|
+
// Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences
|
|
1412
|
+
// Each threadgroup has BATCH_SIZE threads, each handling one token
|
|
1413
|
+
const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE;
|
|
1414
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1);
|
|
1415
|
+
} else {
|
|
1416
|
+
auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
|
|
1417
|
+
|
|
1418
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1419
|
+
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
|
1420
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
1421
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
1422
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
|
|
1423
|
+
|
|
1424
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
|
|
1425
|
+
}
|
|
1426
|
+
|
|
1427
|
+
return 1;
|
|
1428
|
+
}
|
|
1126
1429
|
|
|
1127
1430
|
int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
|
1128
1431
|
ggml_tensor * op = ctx->node(idx);
|
|
@@ -1145,7 +1448,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
|
|
1145
1448
|
GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
|
|
1146
1449
|
GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
|
|
1147
1450
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1148
|
-
GGML_TENSOR_LOCALS(
|
|
1451
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1149
1452
|
|
|
1150
1453
|
const ggml_tensor * src3 = op->src[3];
|
|
1151
1454
|
const ggml_tensor * src4 = op->src[4];
|
|
@@ -1172,26 +1475,37 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
|
|
1172
1475
|
/*.n_seq_tokens =*/ n_seq_tokens,
|
|
1173
1476
|
/*.n_seqs =*/ n_seqs,
|
|
1174
1477
|
/*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float),
|
|
1478
|
+
/*.nb00 =*/ nb00,
|
|
1175
1479
|
/*.nb01 =*/ nb01,
|
|
1176
1480
|
/*.nb02 =*/ nb02,
|
|
1177
1481
|
/*.nb03 =*/ nb03,
|
|
1482
|
+
/*.nb10 =*/ nb10,
|
|
1178
1483
|
/*.nb11 =*/ nb11,
|
|
1179
1484
|
/*.nb12 =*/ nb12,
|
|
1485
|
+
/*.ns12 =*/ nb12/nb10,
|
|
1180
1486
|
/*.nb13 =*/ nb13,
|
|
1487
|
+
/*.nb20 =*/ nb20,
|
|
1181
1488
|
/*.nb21 =*/ nb21,
|
|
1489
|
+
/*.ns21 =*/ nb21/nb20,
|
|
1182
1490
|
/*.nb22 =*/ nb22,
|
|
1491
|
+
/*.ne30 =*/ ne30,
|
|
1183
1492
|
/*.nb31 =*/ nb31,
|
|
1184
1493
|
/*.nb41 =*/ nb41,
|
|
1185
1494
|
/*.nb42 =*/ nb42,
|
|
1495
|
+
/*.ns42 =*/ nb42/nb40,
|
|
1186
1496
|
/*.nb43 =*/ nb43,
|
|
1187
1497
|
/*.nb51 =*/ nb51,
|
|
1188
1498
|
/*.nb52 =*/ nb52,
|
|
1499
|
+
/*.ns52 =*/ nb52/nb50,
|
|
1189
1500
|
/*.nb53 =*/ nb53,
|
|
1501
|
+
/*.nb0 =*/ nb0,
|
|
1190
1502
|
};
|
|
1191
1503
|
|
|
1192
|
-
|
|
1504
|
+
auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
|
|
1193
1505
|
|
|
1194
|
-
|
|
1506
|
+
GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
1507
|
+
|
|
1508
|
+
const size_t smem = pipeline.smem;
|
|
1195
1509
|
|
|
1196
1510
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1197
1511
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -1204,15 +1518,9 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
|
|
1204
1518
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7);
|
|
1205
1519
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 8);
|
|
1206
1520
|
|
|
1207
|
-
ggml_metal_encoder_set_threadgroup_memory_size(enc,
|
|
1521
|
+
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
1208
1522
|
|
|
1209
|
-
|
|
1210
|
-
// Mamba-2
|
|
1211
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
|
|
1212
|
-
} else {
|
|
1213
|
-
GGML_ASSERT(d_inner == 1);
|
|
1214
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1);
|
|
1215
|
-
}
|
|
1523
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
|
|
1216
1524
|
|
|
1217
1525
|
return 1;
|
|
1218
1526
|
}
|
|
@@ -1226,14 +1534,14 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
|
|
|
1226
1534
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1227
1535
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1228
1536
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1229
|
-
GGML_TENSOR_LOCALS(
|
|
1537
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1230
1538
|
|
|
1231
1539
|
const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
|
|
1232
1540
|
const int64_t T = op->src[0]->ne[2];
|
|
1233
1541
|
const int64_t C = op->ne[0];
|
|
1234
1542
|
const int64_t H = op->src[0]->ne[1];
|
|
1235
1543
|
|
|
1236
|
-
|
|
1544
|
+
auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
|
|
1237
1545
|
|
|
1238
1546
|
int ida = 0;
|
|
1239
1547
|
|
|
@@ -1258,41 +1566,298 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
|
|
|
1258
1566
|
return 1;
|
|
1259
1567
|
}
|
|
1260
1568
|
|
|
1261
|
-
int
|
|
1569
|
+
int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) {
|
|
1262
1570
|
ggml_tensor * op = ctx->node(idx);
|
|
1263
1571
|
|
|
1264
1572
|
ggml_metal_library_t lib = ctx->lib;
|
|
1265
1573
|
ggml_metal_encoder_t enc = ctx->enc;
|
|
1266
1574
|
|
|
1575
|
+
|
|
1267
1576
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1268
1577
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1578
|
+
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
1579
|
+
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
1580
|
+
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
|
1581
|
+
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
|
1269
1582
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1270
|
-
GGML_TENSOR_LOCALS(
|
|
1583
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1271
1584
|
|
|
1272
|
-
|
|
1585
|
+
auto pipeline = ggml_metal_library_get_pipeline_gated_delta_net(lib, op);
|
|
1273
1586
|
|
|
1274
|
-
|
|
1587
|
+
int ida = 0;
|
|
1588
|
+
|
|
1589
|
+
ggml_metal_kargs_gated_delta_net args = {
|
|
1590
|
+
/*.ne00 =*/ ne00,
|
|
1591
|
+
/*.ne01 =*/ ne01,
|
|
1592
|
+
/*.ne02 =*/ ne02,
|
|
1593
|
+
/*.ne03 =*/ ne03,
|
|
1594
|
+
/*.nb00 =*/ nb00,
|
|
1595
|
+
/*.nb01 =*/ nb01,
|
|
1596
|
+
/*.nb02 =*/ nb02,
|
|
1597
|
+
/*.nb03 =*/ nb03,
|
|
1598
|
+
/*.ne10 =*/ ne10,
|
|
1599
|
+
/*.ne11 =*/ ne11,
|
|
1600
|
+
/*.ne12 =*/ ne12,
|
|
1601
|
+
/*.ne13 =*/ ne13,
|
|
1602
|
+
/*.nb10 =*/ nb10,
|
|
1603
|
+
/*.nb11 =*/ nb11,
|
|
1604
|
+
/*.nb12 =*/ nb12,
|
|
1605
|
+
/*.nb13 =*/ nb13,
|
|
1606
|
+
/*.ne20 =*/ ne20,
|
|
1607
|
+
/*.ne21 =*/ ne21,
|
|
1608
|
+
/*.ne22 =*/ ne22,
|
|
1609
|
+
/*.ne23 =*/ ne23,
|
|
1610
|
+
/*.nb20 =*/ nb20,
|
|
1611
|
+
/*.nb21 =*/ nb21,
|
|
1612
|
+
/*.nb22 =*/ nb22,
|
|
1613
|
+
/*.nb23 =*/ nb23,
|
|
1614
|
+
/*.ns02 =*/ (int32_t) (nb02/sizeof(float)),
|
|
1615
|
+
/*.ns12 =*/ (int32_t) (nb12/sizeof(float)),
|
|
1616
|
+
/*.ns22 =*/ (int32_t) (nb22/sizeof(float)),
|
|
1617
|
+
/*.ne0 =*/ ne0,
|
|
1618
|
+
/*.ne1 =*/ ne1,
|
|
1619
|
+
/*.ne2 =*/ ne2,
|
|
1620
|
+
/*.ne3 =*/ ne3,
|
|
1621
|
+
/*.nb0 =*/ nb0,
|
|
1622
|
+
/*.nb1 =*/ nb1,
|
|
1623
|
+
/*.nb2 =*/ nb2,
|
|
1624
|
+
/*.nb3 =*/ nb3,
|
|
1625
|
+
};
|
|
1275
1626
|
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
|
|
1627
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1628
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
|
|
1629
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); // q
|
|
1630
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); // k
|
|
1631
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); // v
|
|
1632
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); // gate
|
|
1633
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); // beta
|
|
1634
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); // state
|
|
1635
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); // dst
|
|
1279
1636
|
|
|
1280
|
-
int
|
|
1637
|
+
const int nsg = pipeline.nsg;
|
|
1281
1638
|
|
|
1282
|
-
|
|
1283
|
-
|
|
1639
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, op->src[2]->ne[0]/nsg, op->src[2]->ne[1], op->src[2]->ne[3], 32, nsg, 1);
|
|
1640
|
+
|
|
1641
|
+
return 1;
|
|
1642
|
+
}
|
|
1643
|
+
|
|
1644
|
+
int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
|
|
1645
|
+
ggml_tensor * op = ctx->node(idx);
|
|
1646
|
+
|
|
1647
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
1648
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
1649
|
+
|
|
1650
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1651
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1652
|
+
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
1653
|
+
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
1654
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1655
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1656
|
+
|
|
1657
|
+
ggml_metal_kargs_solve_tri args = {
|
|
1658
|
+
/*.ne00 =*/ ne00,
|
|
1659
|
+
/*.ne01 =*/ ne01,
|
|
1660
|
+
/*.ne02 =*/ ne02,
|
|
1661
|
+
/*.ne03 =*/ ne03,
|
|
1662
|
+
/*.nb00 =*/ nb00,
|
|
1663
|
+
/*.nb01 =*/ nb01,
|
|
1664
|
+
/*.nb02 =*/ nb02,
|
|
1665
|
+
/*.nb03 =*/ nb03,
|
|
1666
|
+
/*.ne10 =*/ ne10,
|
|
1667
|
+
/*.ne11 =*/ ne11,
|
|
1668
|
+
/*.ne12 =*/ ne12,
|
|
1669
|
+
/*.ne13 =*/ ne13,
|
|
1670
|
+
/*.nb10 =*/ nb10,
|
|
1671
|
+
/*.nb11 =*/ nb11,
|
|
1672
|
+
/*.nb12 =*/ nb12,
|
|
1673
|
+
/*.nb13 =*/ nb13,
|
|
1674
|
+
/*.ne0 =*/ ne0,
|
|
1675
|
+
/*.ne1 =*/ ne1,
|
|
1676
|
+
/*.ne2 =*/ ne2,
|
|
1677
|
+
/*.ne3 =*/ ne3,
|
|
1678
|
+
/*.nb0 =*/ nb0,
|
|
1679
|
+
/*.nb1 =*/ nb1,
|
|
1680
|
+
/*.nb2 =*/ nb2,
|
|
1681
|
+
/*.nb3 =*/ nb3,
|
|
1682
|
+
};
|
|
1683
|
+
|
|
1684
|
+
auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
|
|
1685
|
+
|
|
1686
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1687
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
1688
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
1689
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
1690
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
|
1691
|
+
|
|
1692
|
+
const int nsg = pipeline.nsg;
|
|
1693
|
+
|
|
1694
|
+
ggml_metal_encoder_set_threadgroup_memory_size(enc, pipeline.smem, 0);
|
|
1695
|
+
|
|
1696
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, (ne10 + nsg - 1)/nsg, ne02, ne03, 32, nsg, 1);
|
|
1697
|
+
|
|
1698
|
+
return 1;
|
|
1699
|
+
}
|
|
1700
|
+
|
|
1701
|
+
int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
|
|
1702
|
+
ggml_tensor * op = ctx->node(idx);
|
|
1703
|
+
|
|
1704
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
1705
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
1706
|
+
|
|
1707
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1708
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1709
|
+
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
1710
|
+
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
1711
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1712
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1713
|
+
|
|
1714
|
+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
1715
|
+
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
|
|
1716
|
+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
1717
|
+
|
|
1718
|
+
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
|
|
1719
|
+
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
|
|
1720
|
+
const size_t pnb3 = ((const int32_t *) op->op_params)[2];
|
|
1721
|
+
const size_t offs = ((const int32_t *) op->op_params)[3];
|
|
1722
|
+
|
|
1723
|
+
const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
|
|
1724
|
+
|
|
1725
|
+
if (!inplace) {
|
|
1726
|
+
// run a separate kernel to cpy src->dst
|
|
1727
|
+
// not sure how to avoid this
|
|
1728
|
+
// TODO: make a simpler cpy_bytes kernel
|
|
1729
|
+
|
|
1730
|
+
//const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
|
|
1731
|
+
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
|
1732
|
+
|
|
1733
|
+
ggml_metal_kargs_cpy args = {
|
|
1734
|
+
/*.nk0 =*/ ne00,
|
|
1735
|
+
/*.ne00 =*/ ne00,
|
|
1736
|
+
/*.ne01 =*/ ne01,
|
|
1737
|
+
/*.ne02 =*/ ne02,
|
|
1738
|
+
/*.ne03 =*/ ne03,
|
|
1739
|
+
/*.nb00 =*/ nb00,
|
|
1740
|
+
/*.nb01 =*/ nb01,
|
|
1741
|
+
/*.nb02 =*/ nb02,
|
|
1742
|
+
/*.nb03 =*/ nb03,
|
|
1743
|
+
/*.ne0 =*/ ne0,
|
|
1744
|
+
/*.ne1 =*/ ne1,
|
|
1745
|
+
/*.ne2 =*/ ne2,
|
|
1746
|
+
/*.ne3 =*/ ne3,
|
|
1747
|
+
/*.nb0 =*/ nb0,
|
|
1748
|
+
/*.nb1 =*/ nb1,
|
|
1749
|
+
/*.nb2 =*/ nb2,
|
|
1750
|
+
/*.nb3 =*/ nb3,
|
|
1751
|
+
};
|
|
1752
|
+
|
|
1753
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1754
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
1755
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
1756
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
1757
|
+
|
|
1758
|
+
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
|
|
1759
|
+
|
|
1760
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
|
1761
|
+
|
|
1762
|
+
ggml_metal_op_concurrency_reset(ctx);
|
|
1284
1763
|
}
|
|
1285
1764
|
|
|
1286
|
-
|
|
1765
|
+
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);
|
|
1766
|
+
|
|
1767
|
+
GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);
|
|
1768
|
+
|
|
1769
|
+
int64_t nk0 = ne10;
|
|
1770
|
+
if (ggml_is_quantized(op->src[1]->type)) {
|
|
1771
|
+
nk0 = ne10/16;
|
|
1772
|
+
} else if (ggml_is_quantized(op->type)) {
|
|
1773
|
+
nk0 = ne10/ggml_blck_size(op->type);
|
|
1774
|
+
}
|
|
1775
|
+
|
|
1776
|
+
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
1777
|
+
|
|
1778
|
+
// when rows are small, we can batch them together in a single threadgroup
|
|
1779
|
+
int nrptg = 1;
|
|
1780
|
+
|
|
1781
|
+
// TODO: relax this constraint in the future
|
|
1782
|
+
if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {
|
|
1783
|
+
if (nth > nk0) {
|
|
1784
|
+
nrptg = (nth + nk0 - 1)/nk0;
|
|
1785
|
+
nth = nk0;
|
|
1786
|
+
|
|
1787
|
+
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
1788
|
+
nrptg--;
|
|
1789
|
+
}
|
|
1790
|
+
}
|
|
1791
|
+
}
|
|
1792
|
+
|
|
1793
|
+
nth = std::min<int>(nth, nk0);
|
|
1794
|
+
|
|
1795
|
+
ggml_metal_kargs_cpy args = {
|
|
1796
|
+
/*.nk0 =*/ nk0,
|
|
1797
|
+
/*.ne00 =*/ ne10,
|
|
1798
|
+
/*.ne01 =*/ ne11,
|
|
1799
|
+
/*.ne02 =*/ ne12,
|
|
1800
|
+
/*.ne03 =*/ ne13,
|
|
1801
|
+
/*.nb00 =*/ nb10,
|
|
1802
|
+
/*.nb01 =*/ nb11,
|
|
1803
|
+
/*.nb02 =*/ nb12,
|
|
1804
|
+
/*.nb03 =*/ nb13,
|
|
1805
|
+
/*.ne0 =*/ ne10,
|
|
1806
|
+
/*.ne1 =*/ ne11,
|
|
1807
|
+
/*.ne2 =*/ ne12,
|
|
1808
|
+
/*.ne3 =*/ ne13,
|
|
1809
|
+
/*.nb0 =*/ ggml_element_size(op),
|
|
1810
|
+
/*.nb1 =*/ pnb1,
|
|
1811
|
+
/*.nb2 =*/ pnb2,
|
|
1812
|
+
/*.nb3 =*/ pnb3,
|
|
1813
|
+
};
|
|
1814
|
+
|
|
1815
|
+
const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
|
|
1816
|
+
|
|
1817
|
+
bid_dst.offs += offs;
|
|
1818
|
+
|
|
1819
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1820
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
1821
|
+
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
|
|
1822
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
1823
|
+
|
|
1824
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);
|
|
1825
|
+
|
|
1826
|
+
return 1;
|
|
1827
|
+
}
|
|
1828
|
+
|
|
1829
|
+
int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
|
1830
|
+
ggml_tensor * op = ctx->node(idx);
|
|
1831
|
+
|
|
1832
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
1833
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
1834
|
+
|
|
1835
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1836
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1837
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1838
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1839
|
+
|
|
1840
|
+
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
|
1841
|
+
|
|
1842
|
+
GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
|
|
1843
|
+
|
|
1844
|
+
int64_t nk0 = ne00;
|
|
1845
|
+
if (ggml_is_quantized(op->src[0]->type)) {
|
|
1846
|
+
nk0 = ne00/16;
|
|
1847
|
+
} else if (ggml_is_quantized(op->type)) {
|
|
1848
|
+
nk0 = ne00/ggml_blck_size(op->type);
|
|
1849
|
+
}
|
|
1850
|
+
|
|
1851
|
+
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
1287
1852
|
|
|
1288
1853
|
// when rows are small, we can batch them together in a single threadgroup
|
|
1289
1854
|
int nrptg = 1;
|
|
1290
1855
|
|
|
1291
1856
|
// TODO: relax this constraint in the future
|
|
1292
1857
|
if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) {
|
|
1293
|
-
if (nth >
|
|
1294
|
-
nrptg = (nth +
|
|
1295
|
-
nth =
|
|
1858
|
+
if (nth > nk0) {
|
|
1859
|
+
nrptg = (nth + nk0 - 1)/nk0;
|
|
1860
|
+
nth = nk0;
|
|
1296
1861
|
|
|
1297
1862
|
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
1298
1863
|
nrptg--;
|
|
@@ -1300,10 +1865,11 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
|
|
1300
1865
|
}
|
|
1301
1866
|
}
|
|
1302
1867
|
|
|
1303
|
-
nth = std::min(nth,
|
|
1868
|
+
nth = std::min<int>(nth, nk0);
|
|
1304
1869
|
|
|
1305
1870
|
ggml_metal_kargs_cpy args = {
|
|
1306
|
-
/*.
|
|
1871
|
+
/*.nk0 =*/ nk0,
|
|
1872
|
+
/*.ne00 =*/ ne00,
|
|
1307
1873
|
/*.ne01 =*/ ne01,
|
|
1308
1874
|
/*.ne02 =*/ ne02,
|
|
1309
1875
|
/*.ne03 =*/ ne03,
|
|
@@ -1321,16 +1887,66 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
|
|
1321
1887
|
/*.nb3 =*/ nb3,
|
|
1322
1888
|
};
|
|
1323
1889
|
|
|
1890
|
+
const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
|
|
1891
|
+
|
|
1324
1892
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1325
1893
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
1326
1894
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
1327
1895
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
1328
1896
|
|
|
1329
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1);
|
|
1897
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
|
|
1898
|
+
|
|
1899
|
+
return 1;
|
|
1900
|
+
}
|
|
1901
|
+
|
|
1902
|
+
int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) {
|
|
1903
|
+
ggml_tensor * op = ctx->node(idx);
|
|
1904
|
+
|
|
1905
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
1906
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
1907
|
+
|
|
1908
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1909
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1910
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1911
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1912
|
+
|
|
1913
|
+
const int32_t * opts = op->op_params;
|
|
1914
|
+
ggml_op_pool op_pool = (ggml_op_pool) opts[0];
|
|
1915
|
+
|
|
1916
|
+
const int32_t k0 = opts[1];
|
|
1917
|
+
const int32_t s0 = opts[2];
|
|
1918
|
+
const int32_t p0 = opts[3];
|
|
1919
|
+
|
|
1920
|
+
const int64_t IW = op->src[0]->ne[0];
|
|
1921
|
+
const int64_t OW = op->ne[0];
|
|
1922
|
+
|
|
1923
|
+
const int64_t np = ggml_nelements(op);
|
|
1924
|
+
|
|
1925
|
+
ggml_metal_kargs_pool_1d args_pool_1d = {
|
|
1926
|
+
/* .k0 = */ k0,
|
|
1927
|
+
/* .s0 = */ s0,
|
|
1928
|
+
/* .p0 = */ p0,
|
|
1929
|
+
/* .IW = */ IW,
|
|
1930
|
+
/* .OW = */ OW,
|
|
1931
|
+
/* .np = */ np
|
|
1932
|
+
};
|
|
1933
|
+
|
|
1934
|
+
auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool);
|
|
1935
|
+
|
|
1936
|
+
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
|
|
1937
|
+
const int ntg = (np + nth - 1) / nth;
|
|
1938
|
+
|
|
1939
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1940
|
+
ggml_metal_encoder_set_bytes (enc, &args_pool_1d, sizeof(args_pool_1d), 0);
|
|
1941
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
1942
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
1943
|
+
|
|
1944
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
|
|
1330
1945
|
|
|
1331
1946
|
return 1;
|
|
1332
1947
|
}
|
|
1333
1948
|
|
|
1949
|
+
|
|
1334
1950
|
int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
|
|
1335
1951
|
ggml_tensor * op = ctx->node(idx);
|
|
1336
1952
|
|
|
@@ -1340,7 +1956,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
|
|
|
1340
1956
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1341
1957
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1342
1958
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1343
|
-
GGML_TENSOR_LOCALS(
|
|
1959
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1344
1960
|
|
|
1345
1961
|
const int32_t * opts = op->op_params;
|
|
1346
1962
|
ggml_op_pool op_pool = (ggml_op_pool) opts[0];
|
|
@@ -1376,7 +1992,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
|
|
|
1376
1992
|
/* .np = */ np
|
|
1377
1993
|
};
|
|
1378
1994
|
|
|
1379
|
-
|
|
1995
|
+
auto pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
|
|
1380
1996
|
|
|
1381
1997
|
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
|
|
1382
1998
|
const int ntg = (np + nth - 1) / nth;
|
|
@@ -1404,7 +2020,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|
|
1404
2020
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
1405
2021
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
1406
2022
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1407
|
-
GGML_TENSOR_LOCALS(
|
|
2023
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1408
2024
|
|
|
1409
2025
|
GGML_ASSERT(ne00 == ne10);
|
|
1410
2026
|
|
|
@@ -1426,6 +2042,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|
|
1426
2042
|
(
|
|
1427
2043
|
op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
|
|
1428
2044
|
op->src[0]->type == GGML_TYPE_F16 ||
|
|
2045
|
+
op->src[0]->type == GGML_TYPE_BF16 ||
|
|
1429
2046
|
op->src[0]->type == GGML_TYPE_Q4_0 ||
|
|
1430
2047
|
op->src[0]->type == GGML_TYPE_Q4_1 ||
|
|
1431
2048
|
op->src[0]->type == GGML_TYPE_Q5_0 ||
|
|
@@ -1440,6 +2057,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|
|
1440
2057
|
op->src[0]->type == GGML_TYPE_Q4_K ||
|
|
1441
2058
|
op->src[0]->type == GGML_TYPE_Q5_K ||
|
|
1442
2059
|
op->src[0]->type == GGML_TYPE_Q6_K ||
|
|
2060
|
+
op->src[0]->type == GGML_TYPE_Q2_K ||
|
|
2061
|
+
op->src[0]->type == GGML_TYPE_Q3_K ||
|
|
1443
2062
|
false) && (ne11 >= 4 && ne11 <= 8)
|
|
1444
2063
|
)
|
|
1445
2064
|
)
|
|
@@ -1468,7 +2087,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|
|
1468
2087
|
const int16_t r0ptg = nypsg*nsg; // num src0 rows per threadgroup
|
|
1469
2088
|
int16_t r1ptg = 4; // num src1 rows per threadgroup
|
|
1470
2089
|
|
|
1471
|
-
// note: not sure how optimal are those across all different hardware. there might be
|
|
2090
|
+
// note: not sure how optimal are those across all different hardware. there might be something cleverer
|
|
1472
2091
|
switch (ne11) {
|
|
1473
2092
|
case 2:
|
|
1474
2093
|
r1ptg = 2; break;
|
|
@@ -1485,7 +2104,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|
|
1485
2104
|
GGML_ABORT("unsupported ne11");
|
|
1486
2105
|
};
|
|
1487
2106
|
|
|
1488
|
-
|
|
2107
|
+
auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
|
|
1489
2108
|
|
|
1490
2109
|
ggml_metal_kargs_mul_mv_ext args = {
|
|
1491
2110
|
/*.ne00 =*/ ne00,
|
|
@@ -1520,9 +2139,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|
|
1520
2139
|
!ggml_is_transposed(op->src[1]) &&
|
|
1521
2140
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
|
1522
2141
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
|
1523
|
-
props_dev->has_simdgroup_mm && ne00 >= 64 &&
|
|
1524
|
-
(
|
|
1525
|
-
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
|
2142
|
+
props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
|
|
2143
|
+
//GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
|
1526
2144
|
|
|
1527
2145
|
// some Metal matrix data types require aligned pointers
|
|
1528
2146
|
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
|
@@ -1533,7 +2151,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|
|
1533
2151
|
// default: break;
|
|
1534
2152
|
//}
|
|
1535
2153
|
|
|
1536
|
-
|
|
2154
|
+
auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
|
|
1537
2155
|
|
|
1538
2156
|
ggml_metal_kargs_mul_mm args = {
|
|
1539
2157
|
/*.ne00 =*/ ne00,
|
|
@@ -1558,18 +2176,18 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|
|
1558
2176
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
1559
2177
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
|
1560
2178
|
|
|
1561
|
-
const size_t smem =
|
|
2179
|
+
const size_t smem = pipeline.smem;
|
|
1562
2180
|
|
|
1563
2181
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
1564
2182
|
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);
|
|
1565
2183
|
} else {
|
|
1566
|
-
|
|
2184
|
+
auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
|
|
1567
2185
|
|
|
1568
|
-
const int nr0 =
|
|
1569
|
-
const int nr1 =
|
|
1570
|
-
const int nsg =
|
|
2186
|
+
const int nr0 = pipeline.nr0;
|
|
2187
|
+
const int nr1 = pipeline.nr1;
|
|
2188
|
+
const int nsg = pipeline.nsg;
|
|
1571
2189
|
|
|
1572
|
-
const size_t smem =
|
|
2190
|
+
const size_t smem = pipeline.smem;
|
|
1573
2191
|
|
|
1574
2192
|
ggml_metal_kargs_mul_mv args = {
|
|
1575
2193
|
/*.ne00 =*/ ne00,
|
|
@@ -1646,7 +2264,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
|
|
1646
2264
|
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
|
1647
2265
|
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
|
1648
2266
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1649
|
-
GGML_TENSOR_LOCALS(
|
|
2267
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1650
2268
|
|
|
1651
2269
|
// src2 = ids
|
|
1652
2270
|
GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);
|
|
@@ -1700,9 +2318,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
|
|
1700
2318
|
nb21,
|
|
1701
2319
|
};
|
|
1702
2320
|
|
|
1703
|
-
|
|
2321
|
+
auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
|
|
1704
2322
|
|
|
1705
|
-
const size_t smem =
|
|
2323
|
+
const size_t smem = pipeline.smem;
|
|
1706
2324
|
|
|
1707
2325
|
GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
1708
2326
|
|
|
@@ -1723,7 +2341,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
|
|
1723
2341
|
ggml_metal_op_concurrency_reset(ctx);
|
|
1724
2342
|
|
|
1725
2343
|
{
|
|
1726
|
-
|
|
2344
|
+
auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
|
|
1727
2345
|
|
|
1728
2346
|
ggml_metal_kargs_mul_mm_id args = {
|
|
1729
2347
|
/*.ne00 =*/ ne00,
|
|
@@ -1752,20 +2370,20 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
|
|
1752
2370
|
ggml_metal_encoder_set_buffer (enc, bid_ids, 4);
|
|
1753
2371
|
ggml_metal_encoder_set_buffer (enc, bid_dst, 5);
|
|
1754
2372
|
|
|
1755
|
-
const size_t smem =
|
|
2373
|
+
const size_t smem = pipeline.smem;
|
|
1756
2374
|
|
|
1757
2375
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
1758
2376
|
|
|
1759
2377
|
ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
|
|
1760
2378
|
}
|
|
1761
2379
|
} else {
|
|
1762
|
-
|
|
2380
|
+
auto pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
|
|
1763
2381
|
|
|
1764
|
-
const int nr0 =
|
|
1765
|
-
const int nr1 =
|
|
1766
|
-
const int nsg =
|
|
2382
|
+
const int nr0 = pipeline.nr0;
|
|
2383
|
+
const int nr1 = pipeline.nr1;
|
|
2384
|
+
const int nsg = pipeline.nsg;
|
|
1767
2385
|
|
|
1768
|
-
const size_t smem =
|
|
2386
|
+
const size_t smem = pipeline.smem;
|
|
1769
2387
|
|
|
1770
2388
|
ggml_metal_kargs_mul_mv_id args = {
|
|
1771
2389
|
/*.nei0 =*/ ne20,
|
|
@@ -1849,7 +2467,7 @@ int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) {
|
|
|
1849
2467
|
/*.nb21 =*/ nb21,
|
|
1850
2468
|
};
|
|
1851
2469
|
|
|
1852
|
-
|
|
2470
|
+
auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID);
|
|
1853
2471
|
|
|
1854
2472
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1855
2473
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -1875,20 +2493,118 @@ bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {
|
|
|
1875
2493
|
return (ne01 < 20) && (ne00 % 32 == 0);
|
|
1876
2494
|
}
|
|
1877
2495
|
|
|
2496
|
+
size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
|
|
2497
|
+
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
|
2498
|
+
|
|
2499
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2500
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2501
|
+
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
2502
|
+
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
2503
|
+
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
|
2504
|
+
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
|
2505
|
+
GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
|
|
2506
|
+
GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
|
|
2507
|
+
|
|
2508
|
+
size_t res = 0;
|
|
2509
|
+
|
|
2510
|
+
const bool has_mask = op->src[3] != nullptr;
|
|
2511
|
+
|
|
2512
|
+
// note: the non-vec kernel requires more extra memory, so always reserve for it
|
|
2513
|
+
GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG);
|
|
2514
|
+
|
|
2515
|
+
//if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
|
2516
|
+
if (false) {
|
|
2517
|
+
// note: always reserve the padding space to avoid graph reallocations
|
|
2518
|
+
//const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
|
|
2519
|
+
const bool has_kvpad = true;
|
|
2520
|
+
|
|
2521
|
+
if (has_kvpad) {
|
|
2522
|
+
res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
|
|
2523
|
+
nb11*ne12*ne13 +
|
|
2524
|
+
nb21*ne22*ne23 +
|
|
2525
|
+
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
|
|
2526
|
+
}
|
|
2527
|
+
} else {
|
|
2528
|
+
//const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
|
|
2529
|
+
const bool has_kvpad = true;
|
|
2530
|
+
|
|
2531
|
+
if (has_kvpad) {
|
|
2532
|
+
res += OP_FLASH_ATTN_EXT_NCPSG*(
|
|
2533
|
+
nb11*ne12*ne13 +
|
|
2534
|
+
nb21*ne22*ne23 +
|
|
2535
|
+
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
|
|
2536
|
+
}
|
|
2537
|
+
}
|
|
2538
|
+
|
|
2539
|
+
return res;
|
|
2540
|
+
}
|
|
2541
|
+
|
|
2542
|
+
size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
|
|
2543
|
+
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
|
2544
|
+
|
|
2545
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2546
|
+
//GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2547
|
+
//GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
2548
|
+
//GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
2549
|
+
//GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
|
2550
|
+
//GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
|
2551
|
+
GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
|
|
2552
|
+
GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
|
|
2553
|
+
|
|
2554
|
+
size_t res = 0;
|
|
2555
|
+
|
|
2556
|
+
const bool has_mask = op->src[3] != nullptr;
|
|
2557
|
+
|
|
2558
|
+
if (!has_mask) {
|
|
2559
|
+
return res;
|
|
2560
|
+
}
|
|
2561
|
+
|
|
2562
|
+
const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
|
|
2563
|
+
|
|
2564
|
+
// this optimization is not useful for the vector kernels
|
|
2565
|
+
// note: always reserve the blk buffer to avoid graph reallocations
|
|
2566
|
+
//if (is_vec) {
|
|
2567
|
+
// return res;
|
|
2568
|
+
//}
|
|
2569
|
+
|
|
2570
|
+
const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG;
|
|
2571
|
+
const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
|
|
2572
|
+
|
|
2573
|
+
const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
|
|
2574
|
+
const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
|
|
2575
|
+
|
|
2576
|
+
res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
|
|
2577
|
+
|
|
2578
|
+
return res;
|
|
2579
|
+
}
|
|
2580
|
+
|
|
1878
2581
|
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
|
|
1879
2582
|
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
|
1880
2583
|
|
|
1881
|
-
|
|
2584
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2585
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2586
|
+
//GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
2587
|
+
//GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
2588
|
+
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
|
2589
|
+
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
|
2590
|
+
//GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
|
|
2591
|
+
//GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
|
|
2592
|
+
|
|
2593
|
+
size_t res = 0;
|
|
1882
2594
|
|
|
1883
|
-
|
|
1884
|
-
|
|
1885
|
-
|
|
1886
|
-
|
|
2595
|
+
// note: always reserve the temp buffer to avoid graph reallocations
|
|
2596
|
+
//if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
|
2597
|
+
if (true) {
|
|
2598
|
+
const int64_t nwg = 32;
|
|
2599
|
+
const int64_t ne01_max = std::min(ne01, 32);
|
|
1887
2600
|
|
|
1888
|
-
|
|
1889
|
-
|
|
1890
|
-
|
|
1891
|
-
|
|
2601
|
+
// temp buffer for writing the results from each workgroup
|
|
2602
|
+
// - ne20: the size of the Value head
|
|
2603
|
+
// - + 2: the S and M values for each intermediate result
|
|
2604
|
+
res += ggml_type_size(GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));
|
|
2605
|
+
}
|
|
2606
|
+
|
|
2607
|
+
return res;
|
|
1892
2608
|
}
|
|
1893
2609
|
|
|
1894
2610
|
int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
@@ -1910,8 +2626,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
1910
2626
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1911
2627
|
GGML_TENSOR_LOCALS( int32_t, nb, op, nb);
|
|
1912
2628
|
|
|
1913
|
-
GGML_ASSERT(ne00 % 4
|
|
1914
|
-
GGML_ASSERT(ne11 % 32 == 0);
|
|
2629
|
+
GGML_ASSERT(ne00 % 4 == 0);
|
|
1915
2630
|
|
|
1916
2631
|
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
|
|
1917
2632
|
GGML_ASSERT(op->src[1]->type == op->src[2]->type);
|
|
@@ -1921,8 +2636,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
1921
2636
|
GGML_ASSERT(ne12 == ne22);
|
|
1922
2637
|
|
|
1923
2638
|
GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16);
|
|
1924
|
-
GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >=
|
|
1925
|
-
"the Flash-Attention Metal kernel requires the mask to be
|
|
2639
|
+
GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
|
|
2640
|
+
"the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
|
|
1926
2641
|
|
|
1927
2642
|
float scale;
|
|
1928
2643
|
float max_bias;
|
|
@@ -1949,15 +2664,107 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
1949
2664
|
|
|
1950
2665
|
GGML_ASSERT(ne01 < 65536);
|
|
1951
2666
|
|
|
2667
|
+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
2668
|
+
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
|
|
2669
|
+
ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
|
|
2670
|
+
ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
|
|
2671
|
+
ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
|
|
2672
|
+
|
|
2673
|
+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
2674
|
+
|
|
2675
|
+
ggml_metal_buffer_id bid_pad = bid_dst;
|
|
2676
|
+
bid_pad.offs += ggml_nbytes(op);
|
|
2677
|
+
|
|
2678
|
+
ggml_metal_buffer_id bid_blk = bid_pad;
|
|
2679
|
+
bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
|
|
2680
|
+
|
|
2681
|
+
ggml_metal_buffer_id bid_tmp = bid_blk;
|
|
2682
|
+
bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op);
|
|
2683
|
+
|
|
1952
2684
|
if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
|
1953
2685
|
// half8x8 kernel
|
|
1954
|
-
const
|
|
1955
|
-
const
|
|
2686
|
+
const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup
|
|
2687
|
+
const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
|
|
1956
2688
|
|
|
1957
2689
|
GGML_ASSERT(nqptg <= 32);
|
|
1958
2690
|
GGML_ASSERT(nqptg % 8 == 0);
|
|
1959
2691
|
GGML_ASSERT(ncpsg % 32 == 0);
|
|
1960
2692
|
|
|
2693
|
+
bool need_sync = false;
|
|
2694
|
+
|
|
2695
|
+
const bool has_kvpad = ne11 % ncpsg != 0;
|
|
2696
|
+
|
|
2697
|
+
if (has_kvpad) {
|
|
2698
|
+
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
|
|
2699
|
+
|
|
2700
|
+
ggml_metal_kargs_flash_attn_ext_pad args0 = {
|
|
2701
|
+
/*.ne11 =*/ne11,
|
|
2702
|
+
/*.ne_12_2 =*/ne12,
|
|
2703
|
+
/*.ne_12_3 =*/ne13,
|
|
2704
|
+
/*.nb11 =*/nb11,
|
|
2705
|
+
/*.nb12 =*/nb12,
|
|
2706
|
+
/*.nb13 =*/nb13,
|
|
2707
|
+
/*.nb21 =*/nb21,
|
|
2708
|
+
/*.nb22 =*/nb22,
|
|
2709
|
+
/*.nb23 =*/nb23,
|
|
2710
|
+
/*.ne31 =*/ne31,
|
|
2711
|
+
/*.ne32 =*/ne32,
|
|
2712
|
+
/*.ne33 =*/ne33,
|
|
2713
|
+
/*.nb31 =*/nb31,
|
|
2714
|
+
/*.nb32 =*/nb32,
|
|
2715
|
+
/*.nb33 =*/nb33,
|
|
2716
|
+
};
|
|
2717
|
+
|
|
2718
|
+
auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
|
2719
|
+
|
|
2720
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
|
2721
|
+
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
|
2722
|
+
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
|
|
2723
|
+
ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
|
|
2724
|
+
ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
|
|
2725
|
+
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
|
|
2726
|
+
|
|
2727
|
+
assert(ne12 == ne22);
|
|
2728
|
+
assert(ne13 == ne23);
|
|
2729
|
+
|
|
2730
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
|
|
2731
|
+
|
|
2732
|
+
need_sync = true;
|
|
2733
|
+
}
|
|
2734
|
+
|
|
2735
|
+
if (has_mask) {
|
|
2736
|
+
assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
|
|
2737
|
+
|
|
2738
|
+
ggml_metal_kargs_flash_attn_ext_blk args0 = {
|
|
2739
|
+
/*.ne01 =*/ ne01,
|
|
2740
|
+
/*.ne30 =*/ ne30,
|
|
2741
|
+
/*.ne31 =*/ ne31,
|
|
2742
|
+
/*.ne32 =*/ ne32,
|
|
2743
|
+
/*.ne33 =*/ ne33,
|
|
2744
|
+
/*.nb31 =*/ nb31,
|
|
2745
|
+
/*.nb32 =*/ nb32,
|
|
2746
|
+
/*.nb33 =*/ nb33,
|
|
2747
|
+
};
|
|
2748
|
+
|
|
2749
|
+
auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
|
|
2750
|
+
|
|
2751
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
|
2752
|
+
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
|
2753
|
+
ggml_metal_encoder_set_buffer (enc, bid_src3, 1);
|
|
2754
|
+
ggml_metal_encoder_set_buffer (enc, bid_blk, 2);
|
|
2755
|
+
|
|
2756
|
+
const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
|
|
2757
|
+
const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
|
|
2758
|
+
|
|
2759
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
|
|
2760
|
+
|
|
2761
|
+
need_sync = true;
|
|
2762
|
+
}
|
|
2763
|
+
|
|
2764
|
+
if (need_sync) {
|
|
2765
|
+
ggml_metal_op_concurrency_reset(ctx);
|
|
2766
|
+
}
|
|
2767
|
+
|
|
1961
2768
|
const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
|
|
1962
2769
|
|
|
1963
2770
|
// 2*(2*ncpsg)
|
|
@@ -1985,7 +2792,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
1985
2792
|
|
|
1986
2793
|
// simdgroups per threadgroup (a.k.a. warps)
|
|
1987
2794
|
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
|
1988
|
-
int32_t nsg = 4;
|
|
2795
|
+
int32_t nsg = ne00 >= 512 ? 8 : 4;
|
|
1989
2796
|
|
|
1990
2797
|
const size_t smem = FATTN_SMEM(nsg);
|
|
1991
2798
|
|
|
@@ -2007,6 +2814,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2007
2814
|
/*.nb21 =*/ nb21,
|
|
2008
2815
|
/*.nb22 =*/ nb22,
|
|
2009
2816
|
/*.nb23 =*/ nb23,
|
|
2817
|
+
/*.ne31 =*/ ne31,
|
|
2010
2818
|
/*.ne32 =*/ ne32,
|
|
2011
2819
|
/*.ne33 =*/ ne33,
|
|
2012
2820
|
/*.nb31 =*/ nb31,
|
|
@@ -2023,24 +2831,18 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2023
2831
|
/*.logit_softcap =*/ logit_softcap,
|
|
2024
2832
|
};
|
|
2025
2833
|
|
|
2026
|
-
|
|
2834
|
+
auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
|
|
2027
2835
|
|
|
2028
2836
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2029
2837
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
2030
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
2031
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
2032
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
2033
|
-
|
|
2034
|
-
|
|
2035
|
-
|
|
2036
|
-
|
|
2037
|
-
|
|
2038
|
-
if (op->src[4]) {
|
|
2039
|
-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
|
|
2040
|
-
} else {
|
|
2041
|
-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
|
|
2042
|
-
}
|
|
2043
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 6);
|
|
2838
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
2839
|
+
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
|
|
2840
|
+
ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
|
|
2841
|
+
ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
|
|
2842
|
+
ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
|
|
2843
|
+
ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
|
|
2844
|
+
ggml_metal_encoder_set_buffer (enc, bid_blk, 7);
|
|
2845
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
|
|
2044
2846
|
|
|
2045
2847
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
2046
2848
|
|
|
@@ -2048,14 +2850,63 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2048
2850
|
#undef FATTN_SMEM
|
|
2049
2851
|
} else {
|
|
2050
2852
|
// half4x4 kernel
|
|
2051
|
-
const
|
|
2052
|
-
const
|
|
2053
|
-
const
|
|
2853
|
+
const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup
|
|
2854
|
+
const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
|
|
2855
|
+
const int nhptg = 1; // heads per threadgroup
|
|
2054
2856
|
|
|
2055
2857
|
GGML_ASSERT(nqptg <= 32);
|
|
2056
2858
|
GGML_ASSERT(nqptg % 1 == 0);
|
|
2057
2859
|
GGML_ASSERT(ncpsg % 32 == 0);
|
|
2058
2860
|
|
|
2861
|
+
bool need_sync = false;
|
|
2862
|
+
|
|
2863
|
+
const bool has_kvpad = ne11 % ncpsg != 0;
|
|
2864
|
+
|
|
2865
|
+
if (has_kvpad) {
|
|
2866
|
+
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
|
|
2867
|
+
|
|
2868
|
+
ggml_metal_kargs_flash_attn_ext_pad args0 = {
|
|
2869
|
+
/*.ne11 =*/ne11,
|
|
2870
|
+
/*.ne_12_2 =*/ne12,
|
|
2871
|
+
/*.ne_12_3 =*/ne13,
|
|
2872
|
+
/*.nb11 =*/nb11,
|
|
2873
|
+
/*.nb12 =*/nb12,
|
|
2874
|
+
/*.nb13 =*/nb13,
|
|
2875
|
+
/*.nb21 =*/nb21,
|
|
2876
|
+
/*.nb22 =*/nb22,
|
|
2877
|
+
/*.nb23 =*/nb23,
|
|
2878
|
+
/*.ne31 =*/ne31,
|
|
2879
|
+
/*.ne32 =*/ne32,
|
|
2880
|
+
/*.ne33 =*/ne33,
|
|
2881
|
+
/*.nb31 =*/nb31,
|
|
2882
|
+
/*.nb32 =*/nb32,
|
|
2883
|
+
/*.nb33 =*/nb33,
|
|
2884
|
+
};
|
|
2885
|
+
|
|
2886
|
+
auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
|
2887
|
+
|
|
2888
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
|
2889
|
+
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
|
2890
|
+
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
|
|
2891
|
+
ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
|
|
2892
|
+
ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
|
|
2893
|
+
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
|
|
2894
|
+
|
|
2895
|
+
assert(ne12 == ne22);
|
|
2896
|
+
assert(ne13 == ne23);
|
|
2897
|
+
|
|
2898
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
|
|
2899
|
+
|
|
2900
|
+
need_sync = true;
|
|
2901
|
+
}
|
|
2902
|
+
|
|
2903
|
+
if (need_sync) {
|
|
2904
|
+
ggml_metal_op_concurrency_reset(ctx);
|
|
2905
|
+
}
|
|
2906
|
+
|
|
2907
|
+
// note: for simplicity assume the K is larger or equal than V
|
|
2908
|
+
GGML_ASSERT(ne10 >= ne20);
|
|
2909
|
+
|
|
2059
2910
|
// ne00 + 2*ncpsg*(nsg)
|
|
2060
2911
|
// for each query, we load it as f16 in shared memory (ne00)
|
|
2061
2912
|
// and store the soft_max values and the mask
|
|
@@ -2063,28 +2914,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2063
2914
|
// ne20*(nsg)
|
|
2064
2915
|
// each simdgroup has a full f32 head vector in shared mem to accumulate results
|
|
2065
2916
|
//
|
|
2066
|
-
#define FATTN_SMEM(nsg) (GGML_PAD((
|
|
2067
|
-
|
|
2068
|
-
int64_t nsgmax = 2;
|
|
2069
|
-
while (true) {
|
|
2070
|
-
const size_t smem = FATTN_SMEM(nsgmax);
|
|
2071
|
-
// avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
|
|
2072
|
-
if (smem > props_dev->max_theadgroup_memory_size/2) {
|
|
2073
|
-
break;
|
|
2074
|
-
}
|
|
2075
|
-
nsgmax *= 2;
|
|
2076
|
-
}
|
|
2077
|
-
nsgmax /= 2;
|
|
2078
|
-
|
|
2079
|
-
// simdgroups per threadgroup (a.k.a. warps)
|
|
2080
|
-
//const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
|
|
2081
|
-
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32)));
|
|
2917
|
+
#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16))
|
|
2082
2918
|
|
|
2083
2919
|
int64_t nsg = 1;
|
|
2084
|
-
while (nsg <= nsgt) {
|
|
2085
|
-
nsg *= 2;
|
|
2086
|
-
}
|
|
2087
|
-
nsg /= 2;
|
|
2088
2920
|
|
|
2089
2921
|
// workgroups
|
|
2090
2922
|
// each workgroup handles nsg*nkpsg cache values
|
|
@@ -2097,7 +2929,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2097
2929
|
} else {
|
|
2098
2930
|
nwg = 32;
|
|
2099
2931
|
nsg = 1;
|
|
2100
|
-
while (2*nwg*nsg*
|
|
2932
|
+
while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) {
|
|
2101
2933
|
nsg *= 2;
|
|
2102
2934
|
}
|
|
2103
2935
|
}
|
|
@@ -2120,6 +2952,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2120
2952
|
/*.nb21 =*/ nb21,
|
|
2121
2953
|
/*.nb22 =*/ nb22,
|
|
2122
2954
|
/*.nb23 =*/ nb23,
|
|
2955
|
+
/*.ne31 =*/ ne31,
|
|
2123
2956
|
/*.ne32 =*/ ne32,
|
|
2124
2957
|
/*.ne33 =*/ ne33,
|
|
2125
2958
|
/*.nb31 =*/ nb31,
|
|
@@ -2136,25 +2969,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2136
2969
|
/*.logit_softcap =*/ logit_softcap,
|
|
2137
2970
|
};
|
|
2138
2971
|
|
|
2139
|
-
|
|
2972
|
+
auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
|
|
2140
2973
|
|
|
2141
2974
|
GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
2142
2975
|
|
|
2143
2976
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2144
2977
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
2145
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
2146
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
2147
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
2148
|
-
|
|
2149
|
-
|
|
2150
|
-
} else {
|
|
2151
|
-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
|
|
2152
|
-
}
|
|
2153
|
-
if (op->src[4]) {
|
|
2154
|
-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
|
|
2155
|
-
} else {
|
|
2156
|
-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
|
|
2157
|
-
}
|
|
2978
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
2979
|
+
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
|
|
2980
|
+
ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
|
|
2981
|
+
ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
|
|
2982
|
+
ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
|
|
2158
2983
|
|
|
2159
2984
|
const size_t smem = FATTN_SMEM(nsg);
|
|
2160
2985
|
|
|
@@ -2162,26 +2987,28 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2162
2987
|
GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
|
|
2163
2988
|
|
|
2164
2989
|
if (nwg == 1) {
|
|
2990
|
+
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
|
|
2991
|
+
|
|
2165
2992
|
// using 1 workgroup -> write the result directly into dst
|
|
2166
|
-
ggml_metal_encoder_set_buffer(enc,
|
|
2993
|
+
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
|
|
2994
|
+
ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
|
|
2167
2995
|
|
|
2168
2996
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
2169
2997
|
|
|
2170
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
|
|
2998
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
|
|
2171
2999
|
} else {
|
|
2172
3000
|
// sanity checks
|
|
3001
|
+
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
|
|
3002
|
+
|
|
2173
3003
|
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
|
|
2174
3004
|
GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
|
|
2175
3005
|
|
|
2176
|
-
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
2177
|
-
|
|
2178
3006
|
// write the results from each workgroup into a temp buffer
|
|
2179
|
-
|
|
2180
|
-
bid_tmp
|
|
2181
|
-
ggml_metal_encoder_set_buffer(enc, bid_tmp, 6);
|
|
3007
|
+
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
|
|
3008
|
+
ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
|
|
2182
3009
|
|
|
2183
3010
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
2184
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
|
|
3011
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
|
|
2185
3012
|
|
|
2186
3013
|
// sync the 2 kernels
|
|
2187
3014
|
ggml_metal_op_concurrency_reset(ctx);
|
|
@@ -2194,7 +3021,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2194
3021
|
nrows,
|
|
2195
3022
|
};
|
|
2196
3023
|
|
|
2197
|
-
|
|
3024
|
+
auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
|
|
2198
3025
|
|
|
2199
3026
|
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
|
2200
3027
|
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
|
@@ -2233,8 +3060,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|
|
2233
3060
|
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
2234
3061
|
GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
|
|
2235
3062
|
|
|
2236
|
-
bool bcast_row = false;
|
|
2237
|
-
|
|
2238
3063
|
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
2239
3064
|
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
|
|
2240
3065
|
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
@@ -2326,20 +3151,9 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|
|
2326
3151
|
// the offsets of src1 and all fused buffers are relative to the start of the src1 buffer
|
|
2327
3152
|
bid_src1.offs = 0;
|
|
2328
3153
|
|
|
2329
|
-
|
|
2330
|
-
|
|
2331
|
-
if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
|
2332
|
-
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
3154
|
+
struct ggml_metal_pipeline_with_params pipeline;
|
|
2333
3155
|
|
|
2334
|
-
|
|
2335
|
-
GGML_ASSERT(ne11 == 1);
|
|
2336
|
-
|
|
2337
|
-
pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true);
|
|
2338
|
-
|
|
2339
|
-
bcast_row = true;
|
|
2340
|
-
} else {
|
|
2341
|
-
pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false);
|
|
2342
|
-
}
|
|
3156
|
+
pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse);
|
|
2343
3157
|
|
|
2344
3158
|
if (n_fuse > 1) {
|
|
2345
3159
|
bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
|
|
@@ -2353,20 +3167,26 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|
|
2353
3167
|
}
|
|
2354
3168
|
}
|
|
2355
3169
|
|
|
3170
|
+
if (pipeline.c4) {
|
|
3171
|
+
args.ne00 = ne00/4;
|
|
3172
|
+
args.ne10 = ne10/4;
|
|
3173
|
+
args.ne0 = ne0/4;
|
|
3174
|
+
}
|
|
3175
|
+
|
|
2356
3176
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2357
3177
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
2358
3178
|
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
2359
3179
|
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
|
|
2360
3180
|
ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
|
|
2361
3181
|
|
|
2362
|
-
if (
|
|
2363
|
-
|
|
2364
|
-
|
|
2365
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
|
3182
|
+
if (pipeline.cnt) {
|
|
3183
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, args.ne0, ggml_nrows(op), 1, 1, 1, 1);
|
|
2366
3184
|
} else {
|
|
2367
|
-
int
|
|
3185
|
+
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
3186
|
+
|
|
3187
|
+
int nth = 1;
|
|
2368
3188
|
|
|
2369
|
-
while (
|
|
3189
|
+
while (2*nth < args.ne0 && nth < nth_max) {
|
|
2370
3190
|
nth *= 2;
|
|
2371
3191
|
}
|
|
2372
3192
|
|
|
@@ -2385,41 +3205,61 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
2385
3205
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2386
3206
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2387
3207
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2388
|
-
GGML_TENSOR_LOCALS(
|
|
3208
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3209
|
+
|
|
3210
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
3211
|
+
|
|
3212
|
+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
3213
|
+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
2389
3214
|
|
|
2390
3215
|
float eps;
|
|
2391
3216
|
memcpy(&eps, op->op_params, sizeof(float));
|
|
2392
3217
|
|
|
2393
|
-
int nth = 32; // SIMD width
|
|
2394
|
-
|
|
2395
3218
|
ggml_metal_kargs_l2_norm args = {
|
|
2396
|
-
/*.ne00
|
|
2397
|
-
/*.
|
|
2398
|
-
/*.
|
|
2399
|
-
/*.
|
|
3219
|
+
/*.ne00 =*/ ne00,
|
|
3220
|
+
/*.ne01 =*/ ne01,
|
|
3221
|
+
/*.ne02 =*/ ne02,
|
|
3222
|
+
/*.ne03 =*/ ne03,
|
|
3223
|
+
/*.nb00 =*/ nb00,
|
|
3224
|
+
/*.nb01 =*/ nb01,
|
|
3225
|
+
/*.nb02 =*/ nb02,
|
|
3226
|
+
/*.nb03 =*/ nb03,
|
|
3227
|
+
/*.ne0 =*/ ne0,
|
|
3228
|
+
/*.ne1 =*/ ne1,
|
|
3229
|
+
/*.ne2 =*/ ne2,
|
|
3230
|
+
/*.ne3 =*/ ne3,
|
|
3231
|
+
/*.nb0 =*/ nb0,
|
|
3232
|
+
/*.nb1 =*/ nb1,
|
|
3233
|
+
/*.nb2 =*/ nb2,
|
|
3234
|
+
/*.nb3 =*/ nb3,
|
|
3235
|
+
/*.eps =*/ eps,
|
|
2400
3236
|
};
|
|
2401
3237
|
|
|
2402
|
-
|
|
3238
|
+
auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
|
|
3239
|
+
|
|
3240
|
+
if (pipeline.c4) {
|
|
3241
|
+
args.ne00 = ne00/4;
|
|
3242
|
+
args.ne0 = ne0/4;
|
|
3243
|
+
}
|
|
3244
|
+
|
|
3245
|
+
int nth = 32; // SIMD width
|
|
2403
3246
|
|
|
2404
|
-
while (nth < ne00
|
|
3247
|
+
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
2405
3248
|
nth *= 2;
|
|
2406
3249
|
}
|
|
2407
3250
|
|
|
2408
3251
|
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
2409
|
-
nth = std::min(nth, ne00/4);
|
|
2410
3252
|
|
|
2411
|
-
const size_t smem =
|
|
2412
|
-
|
|
2413
|
-
const int64_t nrows = ggml_nrows(op->src[0]);
|
|
3253
|
+
const size_t smem = pipeline.smem;
|
|
2414
3254
|
|
|
2415
3255
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2416
3256
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
2417
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
2418
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
3257
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
3258
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
2419
3259
|
|
|
2420
3260
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
2421
3261
|
|
|
2422
|
-
ggml_metal_encoder_dispatch_threadgroups(enc,
|
|
3262
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
|
2423
3263
|
|
|
2424
3264
|
return 1;
|
|
2425
3265
|
}
|
|
@@ -2433,7 +3273,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
2433
3273
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2434
3274
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2435
3275
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2436
|
-
GGML_TENSOR_LOCALS(
|
|
3276
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
2437
3277
|
|
|
2438
3278
|
const int32_t ngrp = ((const int32_t *) op->op_params)[0];
|
|
2439
3279
|
|
|
@@ -2451,7 +3291,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
2451
3291
|
/*.eps =*/ eps,
|
|
2452
3292
|
};
|
|
2453
3293
|
|
|
2454
|
-
|
|
3294
|
+
auto pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op);
|
|
2455
3295
|
|
|
2456
3296
|
int nth = 32; // SIMD width
|
|
2457
3297
|
//while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
@@ -2461,7 +3301,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
2461
3301
|
//nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
2462
3302
|
//nth = std::min(nth, ne00/4);
|
|
2463
3303
|
|
|
2464
|
-
const size_t smem =
|
|
3304
|
+
const size_t smem = pipeline.smem;
|
|
2465
3305
|
|
|
2466
3306
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2467
3307
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -2488,7 +3328,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
2488
3328
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2489
3329
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2490
3330
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2491
|
-
GGML_TENSOR_LOCALS(
|
|
3331
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
2492
3332
|
|
|
2493
3333
|
float eps;
|
|
2494
3334
|
memcpy(&eps, op->op_params, sizeof(float));
|
|
@@ -2586,7 +3426,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
2586
3426
|
}
|
|
2587
3427
|
}
|
|
2588
3428
|
|
|
2589
|
-
|
|
3429
|
+
auto pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
|
|
2590
3430
|
|
|
2591
3431
|
int nth = 32; // SIMD width
|
|
2592
3432
|
|
|
@@ -2597,7 +3437,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
2597
3437
|
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
2598
3438
|
nth = std::min(nth, args.ne00_t);
|
|
2599
3439
|
|
|
2600
|
-
const size_t smem =
|
|
3440
|
+
const size_t smem = pipeline.smem;
|
|
2601
3441
|
|
|
2602
3442
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2603
3443
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -2624,7 +3464,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
|
|
|
2624
3464
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
2625
3465
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
2626
3466
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2627
|
-
GGML_TENSOR_LOCALS(
|
|
3467
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
2628
3468
|
|
|
2629
3469
|
// make sure we have one or more position id(ne10) per token(ne02)
|
|
2630
3470
|
GGML_ASSERT(ne10 % ne02 == 0);
|
|
@@ -2688,9 +3528,10 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
|
|
|
2688
3528
|
/* sect_1 =*/ sect_1,
|
|
2689
3529
|
/* sect_2 =*/ sect_2,
|
|
2690
3530
|
/* sect_3 =*/ sect_3,
|
|
3531
|
+
/* src2 =*/ op->src[2] != nullptr,
|
|
2691
3532
|
};
|
|
2692
3533
|
|
|
2693
|
-
|
|
3534
|
+
auto pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
|
|
2694
3535
|
|
|
2695
3536
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2696
3537
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -2717,7 +3558,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
|
|
|
2717
3558
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2718
3559
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2719
3560
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2720
|
-
GGML_TENSOR_LOCALS(
|
|
3561
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
2721
3562
|
|
|
2722
3563
|
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
|
|
2723
3564
|
const int32_t s1 = ((const int32_t *)(op->op_params))[1];
|
|
@@ -2762,7 +3603,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
|
|
|
2762
3603
|
/*.KHW =*/ KH * KW,
|
|
2763
3604
|
};
|
|
2764
3605
|
|
|
2765
|
-
|
|
3606
|
+
auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
|
|
2766
3607
|
|
|
2767
3608
|
GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
2768
3609
|
|
|
@@ -2770,15 +3611,138 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
|
|
|
2770
3611
|
|
|
2771
3612
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2772
3613
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
2773
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
|
|
2774
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
3614
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
|
|
3615
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
3616
|
+
|
|
3617
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);
|
|
3618
|
+
|
|
3619
|
+
return 1;
|
|
3620
|
+
}
|
|
3621
|
+
|
|
3622
|
+
int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
|
|
3623
|
+
ggml_tensor * op = ctx->node(idx);
|
|
3624
|
+
|
|
3625
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
3626
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
3627
|
+
|
|
3628
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
3629
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
3630
|
+
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
3631
|
+
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
3632
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3633
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3634
|
+
|
|
3635
|
+
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
3636
|
+
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
|
3637
|
+
GGML_ASSERT(op->type == GGML_TYPE_F32);
|
|
3638
|
+
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
|
|
3639
|
+
|
|
3640
|
+
const int32_t s0 = ((const int32_t *) op->op_params)[0];
|
|
3641
|
+
const int32_t s1 = ((const int32_t *) op->op_params)[1];
|
|
3642
|
+
const int32_t p0 = ((const int32_t *) op->op_params)[2];
|
|
3643
|
+
const int32_t p1 = ((const int32_t *) op->op_params)[3];
|
|
3644
|
+
const int32_t d0 = ((const int32_t *) op->op_params)[4];
|
|
3645
|
+
const int32_t d1 = ((const int32_t *) op->op_params)[5];
|
|
3646
|
+
|
|
3647
|
+
ggml_metal_kargs_conv_2d args = {
|
|
3648
|
+
/*.nb00 =*/ nb00,
|
|
3649
|
+
/*.nb01 =*/ nb01,
|
|
3650
|
+
/*.nb02 =*/ nb02,
|
|
3651
|
+
/*.nb03 =*/ nb03,
|
|
3652
|
+
/*.nb10 =*/ nb10,
|
|
3653
|
+
/*.nb11 =*/ nb11,
|
|
3654
|
+
/*.nb12 =*/ nb12,
|
|
3655
|
+
/*.nb13 =*/ nb13,
|
|
3656
|
+
/*.nb0 =*/ nb0,
|
|
3657
|
+
/*.nb1 =*/ nb1,
|
|
3658
|
+
/*.nb2 =*/ nb2,
|
|
3659
|
+
/*.nb3 =*/ nb3,
|
|
3660
|
+
/*.IW =*/ ne10,
|
|
3661
|
+
/*.IH =*/ ne11,
|
|
3662
|
+
/*.KW =*/ ne00,
|
|
3663
|
+
/*.KH =*/ ne01,
|
|
3664
|
+
/*.IC =*/ ne02,
|
|
3665
|
+
/*.OC =*/ ne03,
|
|
3666
|
+
/*.OW =*/ ne0,
|
|
3667
|
+
/*.OH =*/ ne1,
|
|
3668
|
+
/*.N =*/ ne3,
|
|
3669
|
+
/*.s0 =*/ s0,
|
|
3670
|
+
/*.s1 =*/ s1,
|
|
3671
|
+
/*.p0 =*/ p0,
|
|
3672
|
+
/*.p1 =*/ p1,
|
|
3673
|
+
/*.d0 =*/ d0,
|
|
3674
|
+
/*.d1 =*/ d1,
|
|
3675
|
+
};
|
|
3676
|
+
|
|
3677
|
+
auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);
|
|
3678
|
+
|
|
3679
|
+
int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
|
|
3680
|
+
nth = std::min(nth, 256);
|
|
3681
|
+
nth = std::max(nth, 1);
|
|
3682
|
+
|
|
3683
|
+
const uint64_t n_out = ggml_nelements(op);
|
|
3684
|
+
|
|
3685
|
+
uint64_t tg = (n_out + nth - 1)/nth;
|
|
3686
|
+
tg = std::max<uint64_t>(tg, 1);
|
|
3687
|
+
tg = std::min<uint64_t>(tg, (uint64_t) std::numeric_limits<int>::max());
|
|
3688
|
+
|
|
3689
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3690
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
3691
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
3692
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
3693
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
|
3694
|
+
|
|
3695
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, tg, 1, 1, nth, 1, 1);
|
|
3696
|
+
|
|
3697
|
+
return 1;
|
|
3698
|
+
}
|
|
3699
|
+
|
|
3700
|
+
int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
|
|
3701
|
+
ggml_tensor * op = ctx->node(idx);
|
|
3702
|
+
|
|
3703
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
3704
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
3705
|
+
|
|
3706
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
3707
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
3708
|
+
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
3709
|
+
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
3710
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3711
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3712
|
+
|
|
3713
|
+
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
|
|
3714
|
+
|
|
3715
|
+
const int32_t IC = op->src[1]->ne[1];
|
|
3716
|
+
const int32_t IL = op->src[1]->ne[0];
|
|
3717
|
+
|
|
3718
|
+
const int32_t K = op->src[0]->ne[0];
|
|
3719
|
+
|
|
3720
|
+
const int32_t OL = op->ne[0];
|
|
3721
|
+
const int32_t OC = op->ne[1];
|
|
3722
|
+
|
|
3723
|
+
ggml_metal_kargs_conv_transpose_1d args = {
|
|
3724
|
+
/*.IC =*/ IC,
|
|
3725
|
+
/*.IL =*/ IL,
|
|
3726
|
+
/*.K =*/ K,
|
|
3727
|
+
/*.s0 =*/ s0,
|
|
3728
|
+
/*.nb0 =*/ nb0,
|
|
3729
|
+
/*.nb1 =*/ nb1,
|
|
3730
|
+
};
|
|
3731
|
+
|
|
3732
|
+
auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
|
|
3733
|
+
|
|
3734
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3735
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
3736
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
3737
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
3738
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
|
2775
3739
|
|
|
2776
|
-
ggml_metal_encoder_dispatch_threadgroups(enc,
|
|
3740
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, OL, OC, 1, 1, 1, 1);
|
|
2777
3741
|
|
|
2778
3742
|
return 1;
|
|
2779
3743
|
}
|
|
2780
3744
|
|
|
2781
|
-
int
|
|
3745
|
+
int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
|
|
2782
3746
|
ggml_tensor * op = ctx->node(idx);
|
|
2783
3747
|
|
|
2784
3748
|
ggml_metal_library_t lib = ctx->lib;
|
|
@@ -2789,28 +3753,35 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
|
|
|
2789
3753
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
2790
3754
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
2791
3755
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2792
|
-
GGML_TENSOR_LOCALS(
|
|
3756
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
2793
3757
|
|
|
2794
3758
|
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
|
|
2795
3759
|
|
|
2796
|
-
const int32_t IC = op->src[1]->ne[
|
|
2797
|
-
const int32_t
|
|
3760
|
+
const int32_t IC = op->src[1]->ne[2];
|
|
3761
|
+
const int32_t IH = op->src[1]->ne[1];
|
|
3762
|
+
const int32_t IW = op->src[1]->ne[0];
|
|
2798
3763
|
|
|
2799
|
-
const int32_t
|
|
3764
|
+
const int32_t KH = op->src[0]->ne[1];
|
|
3765
|
+
const int32_t KW = op->src[0]->ne[0];
|
|
2800
3766
|
|
|
2801
|
-
const int32_t
|
|
2802
|
-
const int32_t
|
|
3767
|
+
const int32_t OW = op->ne[0];
|
|
3768
|
+
const int32_t OH = op->ne[1];
|
|
3769
|
+
const int32_t OC = op->ne[2];
|
|
2803
3770
|
|
|
2804
|
-
|
|
3771
|
+
ggml_metal_kargs_conv_transpose_2d args = {
|
|
2805
3772
|
/*.IC =*/ IC,
|
|
2806
|
-
/*.
|
|
2807
|
-
/*.
|
|
3773
|
+
/*.IH =*/ IH,
|
|
3774
|
+
/*.IW =*/ IW,
|
|
3775
|
+
/*.KH =*/ KH,
|
|
3776
|
+
/*.KW =*/ KW,
|
|
3777
|
+
/*.OC =*/ OC,
|
|
2808
3778
|
/*.s0 =*/ s0,
|
|
2809
3779
|
/*.nb0 =*/ nb0,
|
|
2810
3780
|
/*.nb1 =*/ nb1,
|
|
3781
|
+
/*.nb2 =*/ nb2,
|
|
2811
3782
|
};
|
|
2812
3783
|
|
|
2813
|
-
|
|
3784
|
+
auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
|
|
2814
3785
|
|
|
2815
3786
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2816
3787
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -2818,7 +3789,11 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
|
|
|
2818
3789
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
2819
3790
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
|
2820
3791
|
|
|
2821
|
-
|
|
3792
|
+
// Metal requires buffer size to be multiple of 16 bytes
|
|
3793
|
+
const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16);
|
|
3794
|
+
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
3795
|
+
|
|
3796
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
|
|
2822
3797
|
|
|
2823
3798
|
return 1;
|
|
2824
3799
|
}
|
|
@@ -2832,37 +3807,48 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
|
|
|
2832
3807
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2833
3808
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2834
3809
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2835
|
-
GGML_TENSOR_LOCALS(
|
|
3810
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3811
|
+
|
|
3812
|
+
float sf0 = (float)ne0/op->src[0]->ne[0];
|
|
3813
|
+
float sf1 = (float)ne1/op->src[0]->ne[1];
|
|
3814
|
+
float sf2 = (float)ne2/op->src[0]->ne[2];
|
|
3815
|
+
float sf3 = (float)ne3/op->src[0]->ne[3];
|
|
3816
|
+
|
|
3817
|
+
const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
|
|
2836
3818
|
|
|
2837
|
-
|
|
2838
|
-
|
|
2839
|
-
|
|
2840
|
-
|
|
3819
|
+
float poffs = 0.5f;
|
|
3820
|
+
|
|
3821
|
+
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
|
3822
|
+
poffs = 0.0f;
|
|
3823
|
+
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
|
|
3824
|
+
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
|
|
3825
|
+
}
|
|
2841
3826
|
|
|
2842
3827
|
ggml_metal_kargs_upscale args = {
|
|
2843
|
-
/*.ne00
|
|
2844
|
-
/*.ne01
|
|
2845
|
-
/*.ne02
|
|
2846
|
-
/*.ne03
|
|
2847
|
-
/*.nb00
|
|
2848
|
-
/*.nb01
|
|
2849
|
-
/*.nb02
|
|
2850
|
-
/*.nb03
|
|
2851
|
-
/*.ne0
|
|
2852
|
-
/*.ne1
|
|
2853
|
-
/*.ne2
|
|
2854
|
-
/*.ne3
|
|
2855
|
-
/*.nb0
|
|
2856
|
-
/*.nb1
|
|
2857
|
-
/*.nb2
|
|
2858
|
-
/*.nb3
|
|
2859
|
-
/*.sf0
|
|
2860
|
-
/*.sf1
|
|
2861
|
-
/*.sf2
|
|
2862
|
-
/*.sf3
|
|
3828
|
+
/*.ne00 =*/ ne00,
|
|
3829
|
+
/*.ne01 =*/ ne01,
|
|
3830
|
+
/*.ne02 =*/ ne02,
|
|
3831
|
+
/*.ne03 =*/ ne03,
|
|
3832
|
+
/*.nb00 =*/ nb00,
|
|
3833
|
+
/*.nb01 =*/ nb01,
|
|
3834
|
+
/*.nb02 =*/ nb02,
|
|
3835
|
+
/*.nb03 =*/ nb03,
|
|
3836
|
+
/*.ne0 =*/ ne0,
|
|
3837
|
+
/*.ne1 =*/ ne1,
|
|
3838
|
+
/*.ne2 =*/ ne2,
|
|
3839
|
+
/*.ne3 =*/ ne3,
|
|
3840
|
+
/*.nb0 =*/ nb0,
|
|
3841
|
+
/*.nb1 =*/ nb1,
|
|
3842
|
+
/*.nb2 =*/ nb2,
|
|
3843
|
+
/*.nb3 =*/ nb3,
|
|
3844
|
+
/*.sf0 =*/ sf0,
|
|
3845
|
+
/*.sf1 =*/ sf1,
|
|
3846
|
+
/*.sf2 =*/ sf2,
|
|
3847
|
+
/*.sf3 =*/ sf3,
|
|
3848
|
+
/*.poffs =*/ poffs,
|
|
2863
3849
|
};
|
|
2864
3850
|
|
|
2865
|
-
|
|
3851
|
+
auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
|
|
2866
3852
|
|
|
2867
3853
|
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
|
|
2868
3854
|
|
|
@@ -2885,7 +3871,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
|
|
|
2885
3871
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2886
3872
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2887
3873
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2888
|
-
GGML_TENSOR_LOCALS(
|
|
3874
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
2889
3875
|
|
|
2890
3876
|
ggml_metal_kargs_pad args = {
|
|
2891
3877
|
/*.ne00 =*/ ne00,
|
|
@@ -2906,7 +3892,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
|
|
|
2906
3892
|
/*.nb3 =*/ nb3
|
|
2907
3893
|
};
|
|
2908
3894
|
|
|
2909
|
-
|
|
3895
|
+
auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
|
|
2910
3896
|
|
|
2911
3897
|
const int nth = std::min(1024, ne0);
|
|
2912
3898
|
|
|
@@ -2929,7 +3915,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
|
|
|
2929
3915
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2930
3916
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2931
3917
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2932
|
-
GGML_TENSOR_LOCALS(
|
|
3918
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
2933
3919
|
|
|
2934
3920
|
ggml_metal_kargs_pad_reflect_1d args = {
|
|
2935
3921
|
/*.ne00 =*/ ne00,
|
|
@@ -2952,7 +3938,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
|
|
|
2952
3938
|
/*.p1 =*/ ((const int32_t *)(op->op_params))[1]
|
|
2953
3939
|
};
|
|
2954
3940
|
|
|
2955
|
-
|
|
3941
|
+
auto pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
|
|
2956
3942
|
|
|
2957
3943
|
const int nth = std::min(1024, ne0);
|
|
2958
3944
|
|
|
@@ -2973,7 +3959,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
|
|
|
2973
3959
|
ggml_metal_encoder_t enc = ctx->enc;
|
|
2974
3960
|
|
|
2975
3961
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2976
|
-
GGML_TENSOR_LOCALS(
|
|
3962
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
2977
3963
|
|
|
2978
3964
|
float start;
|
|
2979
3965
|
float step;
|
|
@@ -2989,13 +3975,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
|
|
|
2989
3975
|
|
|
2990
3976
|
const int nth = std::min(1024, ne0);
|
|
2991
3977
|
|
|
2992
|
-
|
|
2993
|
-
|
|
2994
|
-
//[encoder setComputePipelineState:pipeline];
|
|
2995
|
-
//[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
|
|
2996
|
-
//[encoder setBytes:&args length:sizeof(args) atIndex:1];
|
|
2997
|
-
|
|
2998
|
-
//[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
3978
|
+
auto pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
|
|
2999
3979
|
|
|
3000
3980
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3001
3981
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -3015,7 +3995,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
|
|
|
3015
3995
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
3016
3996
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
3017
3997
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3018
|
-
GGML_TENSOR_LOCALS(
|
|
3998
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3019
3999
|
|
|
3020
4000
|
const int dim = op->op_params[0];
|
|
3021
4001
|
const int max_period = op->op_params[1];
|
|
@@ -3026,7 +4006,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
|
|
|
3026
4006
|
/*.max_period =*/ max_period,
|
|
3027
4007
|
};
|
|
3028
4008
|
|
|
3029
|
-
|
|
4009
|
+
auto pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
|
|
3030
4010
|
|
|
3031
4011
|
const int nth = std::max(1, std::min(1024, dim/2));
|
|
3032
4012
|
|
|
@@ -3049,14 +4029,14 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
|
|
|
3049
4029
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
3050
4030
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
3051
4031
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3052
|
-
GGML_TENSOR_LOCALS(
|
|
4032
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3053
4033
|
|
|
3054
4034
|
ggml_metal_kargs_argmax args = {
|
|
3055
4035
|
/*.ne00 = */ ne00,
|
|
3056
4036
|
/*.nb01 = */ nb01,
|
|
3057
4037
|
};
|
|
3058
4038
|
|
|
3059
|
-
|
|
4039
|
+
auto pipeline = ggml_metal_library_get_pipeline_argmax(lib, op);
|
|
3060
4040
|
|
|
3061
4041
|
const int64_t nrows = ggml_nrows(op->src[0]);
|
|
3062
4042
|
|
|
@@ -3065,7 +4045,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
|
|
|
3065
4045
|
nth *= 2;
|
|
3066
4046
|
}
|
|
3067
4047
|
|
|
3068
|
-
const size_t smem =
|
|
4048
|
+
const size_t smem = pipeline.smem;
|
|
3069
4049
|
|
|
3070
4050
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3071
4051
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -3085,74 +4065,397 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
|
|
3085
4065
|
ggml_metal_library_t lib = ctx->lib;
|
|
3086
4066
|
ggml_metal_encoder_t enc = ctx->enc;
|
|
3087
4067
|
|
|
4068
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
4069
|
+
|
|
3088
4070
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
3089
4071
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
3090
4072
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3091
|
-
GGML_TENSOR_LOCALS(
|
|
4073
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
4074
|
+
|
|
4075
|
+
auto pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
|
|
3092
4076
|
|
|
3093
4077
|
// bitonic sort requires the number of elements to be power of 2
|
|
3094
|
-
|
|
3095
|
-
while (
|
|
3096
|
-
|
|
4078
|
+
int nth = 1;
|
|
4079
|
+
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
4080
|
+
nth *= 2;
|
|
3097
4081
|
}
|
|
3098
4082
|
|
|
3099
|
-
|
|
3100
|
-
|
|
3101
|
-
const int64_t nrows = ggml_nrows(op->src[0]);
|
|
4083
|
+
const int npr = (ne00 + nth - 1)/nth;
|
|
3102
4084
|
|
|
3103
4085
|
// Metal kernels require the buffer size to be multiple of 16 bytes
|
|
3104
4086
|
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
|
|
3105
|
-
const size_t smem = GGML_PAD(
|
|
4087
|
+
const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
|
|
4088
|
+
|
|
4089
|
+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
4090
|
+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
4091
|
+
|
|
4092
|
+
ggml_metal_buffer_id bid_tmp = bid_dst;
|
|
4093
|
+
bid_tmp.offs += ggml_nbytes(op);
|
|
4094
|
+
|
|
4095
|
+
if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
|
|
4096
|
+
std::swap(bid_dst, bid_tmp);
|
|
4097
|
+
}
|
|
3106
4098
|
|
|
3107
4099
|
ggml_metal_kargs_argsort args = {
|
|
3108
|
-
/*.
|
|
3109
|
-
/*.
|
|
4100
|
+
/*.ne00 =*/ ne00,
|
|
4101
|
+
/*.ne01 =*/ ne01,
|
|
4102
|
+
/*.ne02 =*/ ne02,
|
|
4103
|
+
/*.ne03 =*/ ne03,
|
|
4104
|
+
/*.nb00 =*/ nb00,
|
|
4105
|
+
/*.nb01 =*/ nb01,
|
|
4106
|
+
/*.nb02 =*/ nb02,
|
|
4107
|
+
/*.nb03 =*/ nb03,
|
|
4108
|
+
/*.ne0 =*/ ne0,
|
|
4109
|
+
/*.ne1 =*/ ne1,
|
|
4110
|
+
/*.ne2 =*/ ne2,
|
|
4111
|
+
/*.ne3 =*/ ne3,
|
|
4112
|
+
/*.top_k =*/ nth,
|
|
3110
4113
|
};
|
|
3111
4114
|
|
|
3112
4115
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3113
4116
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
3114
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
3115
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
4117
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
4118
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
3116
4119
|
|
|
3117
4120
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
3118
4121
|
|
|
3119
|
-
ggml_metal_encoder_dispatch_threadgroups(enc,
|
|
4122
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
|
|
4123
|
+
|
|
4124
|
+
auto pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
|
|
4125
|
+
|
|
4126
|
+
int len = nth;
|
|
4127
|
+
|
|
4128
|
+
while (len < ne00) {
|
|
4129
|
+
ggml_metal_op_concurrency_reset(ctx);
|
|
4130
|
+
|
|
4131
|
+
ggml_metal_kargs_argsort_merge args_merge = {
|
|
4132
|
+
/*.ne00 =*/ ne00,
|
|
4133
|
+
/*.ne01 =*/ ne01,
|
|
4134
|
+
/*.ne02 =*/ ne02,
|
|
4135
|
+
/*.ne03 =*/ ne03,
|
|
4136
|
+
/*.nb00 =*/ nb00,
|
|
4137
|
+
/*.nb01 =*/ nb01,
|
|
4138
|
+
/*.nb02 =*/ nb02,
|
|
4139
|
+
/*.nb03 =*/ nb03,
|
|
4140
|
+
/*.ne0 =*/ ne0,
|
|
4141
|
+
/*.ne1 =*/ ne1,
|
|
4142
|
+
/*.ne2 =*/ ne2,
|
|
4143
|
+
/*.ne3 =*/ ne3,
|
|
4144
|
+
/*.top_k =*/ ne00,
|
|
4145
|
+
/*.len =*/ len,
|
|
4146
|
+
};
|
|
4147
|
+
|
|
4148
|
+
// merges per row
|
|
4149
|
+
const int nm = (ne00 + 2*len - 1) / (2*len);
|
|
4150
|
+
|
|
4151
|
+
const int nth = std::min(512, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));
|
|
4152
|
+
|
|
4153
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
|
|
4154
|
+
ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
|
|
4155
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
4156
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
4157
|
+
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
|
|
4158
|
+
|
|
4159
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
|
|
4160
|
+
|
|
4161
|
+
std::swap(bid_dst, bid_tmp);
|
|
4162
|
+
|
|
4163
|
+
len <<= 1;
|
|
4164
|
+
}
|
|
3120
4165
|
|
|
3121
4166
|
return 1;
|
|
3122
4167
|
}
|
|
3123
4168
|
|
|
3124
|
-
int
|
|
4169
|
+
int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
|
|
3125
4170
|
ggml_tensor * op = ctx->node(idx);
|
|
3126
4171
|
|
|
3127
4172
|
ggml_metal_library_t lib = ctx->lib;
|
|
3128
4173
|
ggml_metal_encoder_t enc = ctx->enc;
|
|
3129
4174
|
|
|
4175
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
4176
|
+
|
|
3130
4177
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
3131
4178
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
3132
4179
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3133
|
-
GGML_TENSOR_LOCALS(
|
|
4180
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
4181
|
+
|
|
4182
|
+
auto pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
|
|
4183
|
+
|
|
4184
|
+
// bitonic sort requires the number of elements to be power of 2
|
|
4185
|
+
int nth = 1;
|
|
4186
|
+
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
4187
|
+
nth *= 2;
|
|
4188
|
+
}
|
|
4189
|
+
|
|
4190
|
+
// blocks per row
|
|
4191
|
+
const int npr = (ne00 + nth - 1)/nth;
|
|
4192
|
+
|
|
4193
|
+
const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
|
|
4194
|
+
|
|
4195
|
+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
4196
|
+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
4197
|
+
|
|
4198
|
+
ggml_metal_buffer_id bid_tmp = bid_dst;
|
|
4199
|
+
bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);
|
|
4200
|
+
|
|
4201
|
+
if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
|
|
4202
|
+
std::swap(bid_dst, bid_tmp);
|
|
4203
|
+
}
|
|
4204
|
+
|
|
4205
|
+
const int top_k = ne0;
|
|
4206
|
+
|
|
4207
|
+
ggml_metal_kargs_argsort args = {
|
|
4208
|
+
/*.ne00 =*/ ne00,
|
|
4209
|
+
/*.ne01 =*/ ne01,
|
|
4210
|
+
/*.ne02 =*/ ne02,
|
|
4211
|
+
/*.ne03 =*/ ne03,
|
|
4212
|
+
/*.nb00 =*/ nb00,
|
|
4213
|
+
/*.nb01 =*/ nb01,
|
|
4214
|
+
/*.nb02 =*/ nb02,
|
|
4215
|
+
/*.nb03 =*/ nb03,
|
|
4216
|
+
/*.ne0 =*/ ne0,
|
|
4217
|
+
/*.ne1 =*/ ne1,
|
|
4218
|
+
/*.ne2 =*/ ne2,
|
|
4219
|
+
/*.ne3 =*/ ne3,
|
|
4220
|
+
/*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices
|
|
4221
|
+
};
|
|
4222
|
+
|
|
4223
|
+
if (npr > 1) {
|
|
4224
|
+
args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
|
|
4225
|
+
}
|
|
4226
|
+
|
|
4227
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
4228
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
4229
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
4230
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
4231
|
+
|
|
4232
|
+
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
4233
|
+
|
|
4234
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
|
|
4235
|
+
|
|
4236
|
+
auto pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
|
|
4237
|
+
|
|
4238
|
+
int len = args.top_k;
|
|
4239
|
+
|
|
4240
|
+
while (len < args.ne0) {
|
|
4241
|
+
ggml_metal_op_concurrency_reset(ctx);
|
|
4242
|
+
|
|
4243
|
+
// merges per row
|
|
4244
|
+
const int nm = (args.ne0 + 2*len - 1) / (2*len);
|
|
4245
|
+
|
|
4246
|
+
const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
|
|
4247
|
+
|
|
4248
|
+
ggml_metal_kargs_argsort_merge args_merge = {
|
|
4249
|
+
/*.ne00 =*/ ne00,
|
|
4250
|
+
/*.ne01 =*/ ne01,
|
|
4251
|
+
/*.ne02 =*/ ne02,
|
|
4252
|
+
/*.ne03 =*/ ne03,
|
|
4253
|
+
/*.nb00 =*/ nb00,
|
|
4254
|
+
/*.nb01 =*/ nb01,
|
|
4255
|
+
/*.nb02 =*/ nb02,
|
|
4256
|
+
/*.nb03 =*/ nb03,
|
|
4257
|
+
/*.ne0 =*/ args.ne0,
|
|
4258
|
+
/*.ne1 =*/ ne1,
|
|
4259
|
+
/*.ne2 =*/ ne2,
|
|
4260
|
+
/*.ne3 =*/ ne3,
|
|
4261
|
+
/*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements
|
|
4262
|
+
/*.len =*/ len,
|
|
4263
|
+
};
|
|
4264
|
+
|
|
4265
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
|
|
4266
|
+
ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
|
|
4267
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
4268
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
4269
|
+
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
|
|
4270
|
+
|
|
4271
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
|
|
4272
|
+
|
|
4273
|
+
std::swap(bid_dst, bid_tmp);
|
|
4274
|
+
|
|
4275
|
+
len <<= 1;
|
|
4276
|
+
}
|
|
4277
|
+
|
|
4278
|
+
return 1;
|
|
4279
|
+
}
|
|
4280
|
+
|
|
4281
|
+
int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
|
|
4282
|
+
ggml_tensor * op = ctx->node(idx);
|
|
3134
4283
|
|
|
3135
|
-
|
|
3136
|
-
|
|
4284
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
4285
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
4286
|
+
|
|
4287
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
4288
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
4289
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
4290
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3137
4291
|
|
|
3138
|
-
|
|
3139
|
-
/*.
|
|
4292
|
+
ggml_metal_kargs_tri args = {
|
|
4293
|
+
/*.ne00 =*/ ne00,
|
|
4294
|
+
/*.ne01 =*/ ne01,
|
|
4295
|
+
/*.ne02 =*/ ne02,
|
|
4296
|
+
/*.ne03 =*/ ne03,
|
|
4297
|
+
/*.nb00 =*/ nb00,
|
|
4298
|
+
/*.nb01 =*/ nb01,
|
|
4299
|
+
/*.nb02 =*/ nb02,
|
|
4300
|
+
/*.nb03 =*/ nb03,
|
|
4301
|
+
/*.ne0 =*/ ne0,
|
|
4302
|
+
/*.ne1 =*/ ne1,
|
|
4303
|
+
/*.ne2 =*/ ne2,
|
|
4304
|
+
/*.ne3 =*/ ne3,
|
|
4305
|
+
/*.nb0 =*/ nb0,
|
|
4306
|
+
/*.nb1 =*/ nb1,
|
|
4307
|
+
/*.nb2 =*/ nb2,
|
|
4308
|
+
/*.nb3 =*/ nb3,
|
|
3140
4309
|
};
|
|
3141
4310
|
|
|
3142
|
-
|
|
4311
|
+
auto pipeline = ggml_metal_library_get_pipeline_tri(lib, op);
|
|
3143
4312
|
|
|
3144
|
-
|
|
4313
|
+
int nth = 32; // SIMD width
|
|
3145
4314
|
|
|
3146
|
-
|
|
3147
|
-
|
|
4315
|
+
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
4316
|
+
nth *= 2;
|
|
3148
4317
|
}
|
|
3149
4318
|
|
|
4319
|
+
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
4320
|
+
nth = std::min(nth, ne00);
|
|
4321
|
+
|
|
3150
4322
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3151
4323
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
3152
4324
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
3153
4325
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
3154
4326
|
|
|
3155
|
-
ggml_metal_encoder_dispatch_threadgroups(enc,
|
|
4327
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
|
4328
|
+
|
|
4329
|
+
return 1;
|
|
4330
|
+
}
|
|
4331
|
+
|
|
4332
|
+
int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
|
|
4333
|
+
ggml_tensor * op = ctx->node(idx);
|
|
4334
|
+
|
|
4335
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
4336
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
4337
|
+
|
|
4338
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
4339
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
4340
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
4341
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
4342
|
+
|
|
4343
|
+
auto pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
|
|
4344
|
+
|
|
4345
|
+
const int64_t np = ggml_nelements(op->src[0]);
|
|
4346
|
+
ggml_metal_kargs_opt_step_adamw args = {
|
|
4347
|
+
/*.np =*/ np,
|
|
4348
|
+
};
|
|
4349
|
+
|
|
4350
|
+
int ida = 0;
|
|
4351
|
+
|
|
4352
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
4353
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
|
|
4354
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
|
|
4355
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
|
|
4356
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
|
|
4357
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
|
|
4358
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
|
|
4359
|
+
|
|
4360
|
+
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
|
|
4361
|
+
const int64_t n = (np + nth - 1) / nth;
|
|
4362
|
+
|
|
4363
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
|
|
4364
|
+
|
|
4365
|
+
return 1;
|
|
4366
|
+
}
|
|
4367
|
+
|
|
4368
|
+
// Encode the OPT_STEP_SGD kernel for node `idx`: binds the kernel arguments
// plus the three source tensors (src[0] = params, src[1], src[2]) and launches
// one thread per element of src[0]. Returns the number of encoded commands (1).
int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
    GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
    GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);

    auto pipe = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);

    // total number of elements to process
    const int64_t n_par = ggml_nelements(op->src[0]);

    ggml_metal_kargs_opt_step_sgd kargs = {
        /*.np =*/ n_par,
    };

    // next free argument-table slot
    int slot = 0;

    ggml_metal_encoder_set_pipeline(enc, pipe);
    ggml_metal_encoder_set_bytes   (enc, &kargs, sizeof(kargs), slot++);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), slot++);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), slot++);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), slot++);

    // threads per threadgroup: one per element along dim 0, capped by the pipeline limit
    const int     nt  = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipe), ne0);
    const int64_t ntg = (n_par + nt - 1)/nt; // ceil-div: enough groups to cover all elements

    ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nt, 1, 1);

    return 1;
}
|
|
4401
|
+
|
|
4402
|
+
// Encode the COUNT_EQUAL op for node `idx` as two kernel dispatches:
// first a memset that zeroes the destination buffer, then the count_equal
// kernel over src[0]/src[1] accumulating into the destination.
// Returns the number of encoded op slots (1).
int ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);

    // pass 1: clear the output buffer to zero with a single-thread memset
    {
        ggml_metal_kargs_memset kargs0 = { /*.val =*/ 0 };

        auto pipe0 = ggml_metal_library_get_pipeline_memset(lib, op);

        ggml_metal_encoder_set_pipeline(enc, pipe0);
        ggml_metal_encoder_set_bytes(enc, &kargs0, sizeof(kargs0), 0);
        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 1);

        ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
    }

    // the second dispatch accumulates into the buffer the first one clears,
    // so the two must not run concurrently
    ggml_metal_op_concurrency_reset(ctx);

    // pass 2: count matching elements of src[0]/src[1] into the output
    {
        ggml_metal_kargs_count_equal kargs1 = {
            /*.ne00 =*/ ne00,
            /*.ne01 =*/ ne01,
            /*.ne02 =*/ ne02,
            /*.ne03 =*/ ne03,
            /*.nb00 =*/ nb00,
            /*.nb01 =*/ nb01,
            /*.nb02 =*/ nb02,
            /*.nb03 =*/ nb03,
            /*.nb10 =*/ nb10,
            /*.nb11 =*/ nb11,
            /*.nb12 =*/ nb12,
            /*.nb13 =*/ nb13,
        };

        auto pipe1 = ggml_metal_library_get_pipeline_count_equal(lib, op);

        // threadgroup-memory size required by the pipeline
        const size_t smem = pipe1.smem;

        // 32 threads per simdgroup, `nsg` simdgroups per threadgroup
        const int nt = 32*pipe1.nsg;

        GGML_ASSERT(nt <= ggml_metal_pipeline_max_theads_per_threadgroup(pipe1));

        ggml_metal_encoder_set_pipeline(enc, pipe1);
        ggml_metal_encoder_set_bytes(enc, &kargs1, sizeof(kargs1), 0);
        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);

        ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
        // one threadgroup per row of src[0] (dims 1..3), rows handled along dim 0
        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nt, 1, 1);
    }

    return 1;
}
|