whispercpp 1.3.5 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +99 -2
- data/ext/extconf.rb +1 -0
- data/ext/ruby_whisper.c +20 -4
- data/ext/ruby_whisper.h +30 -2
- data/ext/ruby_whisper_context.c +216 -124
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_params.c +0 -1
- data/ext/ruby_whisper_segment.c +0 -1
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +4 -1
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +8 -0
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/server/server.cpp +18 -4
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
- data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
- data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
- data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
- data/ext/sources/examples/talk-llama/llama-context.h +27 -28
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
- data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
- data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
- data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
- data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
- data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
- data/ext/sources/examples/talk-llama/llama-model.h +72 -19
- data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
- data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
- data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -22
- data/ext/sources/examples/talk-llama/llama.h +63 -30
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
- data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
- data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
- data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
- data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
- data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/models.h +181 -46
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
- data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
- data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
- data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
- data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
- data/ext/sources/ggml/CMakeLists.txt +9 -3
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +6 -1
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +56 -9
- data/ext/sources/ggml/src/CMakeLists.txt +3 -0
- data/ext/sources/ggml/src/ggml-alloc.c +4 -9
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
- data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
- data/ext/sources/ggml/src/ggml-impl.h +62 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
- data/ext/sources/ggml/src/ggml.c +167 -33
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/src/whisper.cpp +6 -28
- data/sig/whisper.rbs +43 -2
- data/test/test_context_params.rb +82 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_whisper.rb +20 -0
- data/whispercpp.gemspec +1 -1
- metadata +240 -28
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
|
@@ -203,6 +203,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
203
203
|
GGML_ABORT("unsupported op");
|
|
204
204
|
}
|
|
205
205
|
|
|
206
|
+
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
|
207
|
+
return 1;
|
|
208
|
+
}
|
|
209
|
+
|
|
206
210
|
int n_fuse = 1;
|
|
207
211
|
|
|
208
212
|
// check if the current node can run concurrently with other nodes before it
|
|
@@ -283,17 +287,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
283
287
|
n_fuse = ggml_metal_op_acc(ctx, idx);
|
|
284
288
|
} break;
|
|
285
289
|
case GGML_OP_SCALE:
|
|
286
|
-
{
|
|
287
|
-
n_fuse = ggml_metal_op_scale(ctx, idx);
|
|
288
|
-
} break;
|
|
289
290
|
case GGML_OP_FILL:
|
|
290
|
-
{
|
|
291
|
-
n_fuse = ggml_metal_op_fill(ctx, idx);
|
|
292
|
-
} break;
|
|
293
291
|
case GGML_OP_CLAMP:
|
|
294
|
-
|
|
295
|
-
n_fuse = ggml_metal_op_clamp(ctx, idx);
|
|
296
|
-
} break;
|
|
292
|
+
case GGML_OP_LEAKY_RELU:
|
|
297
293
|
case GGML_OP_SQR:
|
|
298
294
|
case GGML_OP_SQRT:
|
|
299
295
|
case GGML_OP_SIN:
|
|
@@ -337,6 +333,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
337
333
|
{
|
|
338
334
|
n_fuse = ggml_metal_op_rwkv(ctx, idx);
|
|
339
335
|
} break;
|
|
336
|
+
case GGML_OP_GATED_DELTA_NET:
|
|
337
|
+
{
|
|
338
|
+
n_fuse = ggml_metal_op_gated_delta_net(ctx, idx);
|
|
339
|
+
} break;
|
|
340
|
+
case GGML_OP_SOLVE_TRI:
|
|
341
|
+
{
|
|
342
|
+
n_fuse = ggml_metal_op_solve_tri(ctx, idx);
|
|
343
|
+
} break;
|
|
340
344
|
case GGML_OP_MUL_MAT:
|
|
341
345
|
{
|
|
342
346
|
n_fuse = ggml_metal_op_mul_mat(ctx, idx);
|
|
@@ -353,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
353
357
|
{
|
|
354
358
|
n_fuse = ggml_metal_op_set_rows(ctx, idx);
|
|
355
359
|
} break;
|
|
360
|
+
case GGML_OP_DIAG:
|
|
361
|
+
{
|
|
362
|
+
n_fuse = ggml_metal_op_diag(ctx, idx);
|
|
363
|
+
} break;
|
|
356
364
|
case GGML_OP_L2_NORM:
|
|
357
365
|
{
|
|
358
366
|
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
|
|
@@ -414,10 +422,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
414
422
|
{
|
|
415
423
|
n_fuse = ggml_metal_op_top_k(ctx, idx);
|
|
416
424
|
} break;
|
|
417
|
-
case GGML_OP_LEAKY_RELU:
|
|
418
|
-
{
|
|
419
|
-
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
|
|
420
|
-
} break;
|
|
421
425
|
case GGML_OP_TRI:
|
|
422
426
|
{
|
|
423
427
|
n_fuse = ggml_metal_op_tri(ctx, idx);
|
|
@@ -426,12 +430,20 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
426
430
|
{
|
|
427
431
|
n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
|
|
428
432
|
} break;
|
|
433
|
+
case GGML_OP_SET:
|
|
434
|
+
{
|
|
435
|
+
n_fuse = ggml_metal_op_set(ctx, idx);
|
|
436
|
+
} break;
|
|
429
437
|
case GGML_OP_DUP:
|
|
430
438
|
case GGML_OP_CPY:
|
|
431
439
|
case GGML_OP_CONT:
|
|
432
440
|
{
|
|
433
441
|
n_fuse = ggml_metal_op_cpy(ctx, idx);
|
|
434
442
|
} break;
|
|
443
|
+
case GGML_OP_POOL_1D:
|
|
444
|
+
{
|
|
445
|
+
n_fuse = ggml_metal_op_pool_1d(ctx, idx);
|
|
446
|
+
} break;
|
|
435
447
|
case GGML_OP_POOL_2D:
|
|
436
448
|
{
|
|
437
449
|
n_fuse = ggml_metal_op_pool_2d(ctx, idx);
|
|
@@ -612,8 +624,8 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|
|
612
624
|
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
|
613
625
|
GGML_ASSERT(op->type == GGML_TYPE_F32);
|
|
614
626
|
|
|
615
|
-
GGML_ASSERT(
|
|
616
|
-
GGML_ASSERT(
|
|
627
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
628
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
|
|
617
629
|
|
|
618
630
|
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
|
|
619
631
|
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
|
|
@@ -623,7 +635,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|
|
623
635
|
const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
|
|
624
636
|
|
|
625
637
|
if (!inplace) {
|
|
626
|
-
// run a
|
|
638
|
+
// run a separate kernel to cpy src->dst
|
|
627
639
|
// not sure how to avoid this
|
|
628
640
|
// TODO: make a simpler cpy_bytes kernel
|
|
629
641
|
|
|
@@ -663,10 +675,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|
|
663
675
|
}
|
|
664
676
|
|
|
665
677
|
ggml_metal_kargs_bin args = {
|
|
666
|
-
/*.ne00 =*/
|
|
667
|
-
/*.ne01 =*/
|
|
668
|
-
/*.ne02 =*/
|
|
669
|
-
/*.ne03 =*/
|
|
678
|
+
/*.ne00 =*/ ne10,
|
|
679
|
+
/*.ne01 =*/ ne11,
|
|
680
|
+
/*.ne02 =*/ ne12,
|
|
681
|
+
/*.ne03 =*/ ne13,
|
|
670
682
|
/*.nb00 =*/ nb00,
|
|
671
683
|
/*.nb01 =*/ pnb1,
|
|
672
684
|
/*.nb02 =*/ pnb2,
|
|
@@ -679,10 +691,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|
|
679
691
|
/*.nb11 =*/ nb11,
|
|
680
692
|
/*.nb12 =*/ nb12,
|
|
681
693
|
/*.nb13 =*/ nb13,
|
|
682
|
-
/*.ne0 =*/
|
|
683
|
-
/*.ne1 =*/
|
|
684
|
-
/*.ne2 =*/
|
|
685
|
-
/*.ne3 =*/
|
|
694
|
+
/*.ne0 =*/ ne10,
|
|
695
|
+
/*.ne1 =*/ ne11,
|
|
696
|
+
/*.ne2 =*/ ne12,
|
|
697
|
+
/*.ne3 =*/ ne13,
|
|
686
698
|
/*.nb0 =*/ nb0,
|
|
687
699
|
/*.nb1 =*/ pnb1,
|
|
688
700
|
/*.nb2 =*/ pnb2,
|
|
@@ -691,7 +703,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|
|
691
703
|
/*.o1 =*/ { 0 },
|
|
692
704
|
};
|
|
693
705
|
|
|
694
|
-
auto pipeline =
|
|
706
|
+
auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD);
|
|
695
707
|
|
|
696
708
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
697
709
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -699,53 +711,20 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|
|
699
711
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
700
712
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
|
701
713
|
|
|
702
|
-
const int
|
|
703
|
-
|
|
704
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
|
|
705
|
-
|
|
706
|
-
return 1;
|
|
707
|
-
}
|
|
708
|
-
|
|
709
|
-
int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
|
|
710
|
-
ggml_tensor * op = ctx->node(idx);
|
|
711
|
-
|
|
712
|
-
ggml_metal_library_t lib = ctx->lib;
|
|
713
|
-
ggml_metal_encoder_t enc = ctx->enc;
|
|
714
|
-
|
|
715
|
-
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
716
|
-
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
717
|
-
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
718
|
-
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
719
|
-
|
|
720
|
-
float scale;
|
|
721
|
-
float bias;
|
|
722
|
-
memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float));
|
|
723
|
-
memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float));
|
|
724
|
-
|
|
725
|
-
ggml_metal_kargs_scale args = {
|
|
726
|
-
/*.scale =*/ scale,
|
|
727
|
-
/*.bias =*/ bias,
|
|
728
|
-
};
|
|
714
|
+
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
729
715
|
|
|
730
|
-
|
|
716
|
+
int nth = 1;
|
|
731
717
|
|
|
732
|
-
|
|
733
|
-
|
|
718
|
+
while (2*nth < args.ne0 && nth < nth_max) {
|
|
719
|
+
nth *= 2;
|
|
734
720
|
}
|
|
735
721
|
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
739
|
-
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
740
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
741
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
742
|
-
|
|
743
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
|
722
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
|
|
744
723
|
|
|
745
724
|
return 1;
|
|
746
725
|
}
|
|
747
726
|
|
|
748
|
-
int
|
|
727
|
+
int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
|
749
728
|
ggml_tensor * op = ctx->node(idx);
|
|
750
729
|
|
|
751
730
|
ggml_metal_library_t lib = ctx->lib;
|
|
@@ -756,94 +735,80 @@ int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
|
|
|
756
735
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
757
736
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
758
737
|
|
|
759
|
-
|
|
738
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
760
739
|
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
};
|
|
740
|
+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
741
|
+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
764
742
|
|
|
765
|
-
|
|
743
|
+
ggml_metal_kargs_unary args = {
|
|
744
|
+
/*.ne00 =*/ ne00,
|
|
745
|
+
/*.ne01 =*/ ne01,
|
|
746
|
+
/*.ne02 =*/ ne02,
|
|
747
|
+
/*.ne03 =*/ ne03,
|
|
748
|
+
/*.nb00 =*/ nb00,
|
|
749
|
+
/*.nb01 =*/ nb01,
|
|
750
|
+
/*.nb02 =*/ nb02,
|
|
751
|
+
/*.nb03 =*/ nb03,
|
|
752
|
+
/*.ne0 =*/ ne0,
|
|
753
|
+
/*.ne1 =*/ ne1,
|
|
754
|
+
/*.ne2 =*/ ne2,
|
|
755
|
+
/*.ne3 =*/ ne3,
|
|
756
|
+
/*.nb0 =*/ nb0,
|
|
757
|
+
/*.nb1 =*/ nb1,
|
|
758
|
+
/*.nb2 =*/ nb2,
|
|
759
|
+
/*.nb3 =*/ nb3,
|
|
760
|
+
/*.slope =*/ 0.0,
|
|
761
|
+
/*.scale =*/ 0.0,
|
|
762
|
+
/*.bias =*/ 0.0,
|
|
763
|
+
/*.val =*/ 0.0,
|
|
764
|
+
/*.min =*/ 0.0,
|
|
765
|
+
/*.max =*/ 0.0,
|
|
766
|
+
};
|
|
766
767
|
|
|
767
|
-
if (
|
|
768
|
-
|
|
768
|
+
if (op->op == GGML_OP_LEAKY_RELU) {
|
|
769
|
+
args.slope = ggml_get_op_params_f32(op, 0);
|
|
769
770
|
}
|
|
770
771
|
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
776
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
777
|
-
|
|
778
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
|
779
|
-
|
|
780
|
-
return 1;
|
|
781
|
-
}
|
|
782
|
-
|
|
783
|
-
int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
|
|
784
|
-
ggml_tensor * op = ctx->node(idx);
|
|
785
|
-
|
|
786
|
-
ggml_metal_library_t lib = ctx->lib;
|
|
787
|
-
ggml_metal_encoder_t enc = ctx->enc;
|
|
788
|
-
|
|
789
|
-
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
790
|
-
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
791
|
-
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
792
|
-
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
793
|
-
|
|
794
|
-
float min;
|
|
795
|
-
float max;
|
|
796
|
-
memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float));
|
|
797
|
-
memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float));
|
|
798
|
-
|
|
799
|
-
ggml_metal_kargs_clamp args = {
|
|
800
|
-
/*.min =*/ min,
|
|
801
|
-
/*.max =*/ max,
|
|
802
|
-
};
|
|
772
|
+
if (op->op == GGML_OP_SCALE) {
|
|
773
|
+
args.scale = ggml_get_op_params_f32(op, 0);
|
|
774
|
+
args.bias = ggml_get_op_params_f32(op, 1);
|
|
775
|
+
}
|
|
803
776
|
|
|
804
|
-
|
|
777
|
+
if (op->op == GGML_OP_FILL) {
|
|
778
|
+
args.val = ggml_get_op_params_f32(op, 0);
|
|
779
|
+
}
|
|
805
780
|
|
|
806
|
-
if (
|
|
807
|
-
|
|
781
|
+
if (op->op == GGML_OP_CLAMP) {
|
|
782
|
+
args.min = ggml_get_op_params_f32(op, 0);
|
|
783
|
+
args.max = ggml_get_op_params_f32(op, 1);
|
|
808
784
|
}
|
|
809
785
|
|
|
810
786
|
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
|
811
787
|
|
|
788
|
+
if (pipeline.c4) {
|
|
789
|
+
args.ne00 = ne00/4;
|
|
790
|
+
args.ne0 = ne0/4;
|
|
791
|
+
}
|
|
792
|
+
|
|
812
793
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
813
794
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
814
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
815
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
816
|
-
|
|
817
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
|
818
|
-
|
|
819
|
-
return 1;
|
|
820
|
-
}
|
|
795
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
796
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
821
797
|
|
|
822
|
-
|
|
823
|
-
|
|
798
|
+
if (pipeline.cnt) {
|
|
799
|
+
const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
|
|
824
800
|
|
|
825
|
-
|
|
826
|
-
|
|
801
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
|
802
|
+
} else {
|
|
803
|
+
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
827
804
|
|
|
828
|
-
|
|
829
|
-
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
830
|
-
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
831
|
-
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
805
|
+
const int nth = MIN(args.ne00, nth_max);
|
|
832
806
|
|
|
833
|
-
|
|
807
|
+
const int nk0 = (args.ne00 + nth - 1)/nth;
|
|
834
808
|
|
|
835
|
-
|
|
836
|
-
n /= 4;
|
|
809
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
|
|
837
810
|
}
|
|
838
811
|
|
|
839
|
-
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
|
840
|
-
|
|
841
|
-
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
842
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
|
|
843
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
|
|
844
|
-
|
|
845
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
|
846
|
-
|
|
847
812
|
return 1;
|
|
848
813
|
}
|
|
849
814
|
|
|
@@ -953,6 +918,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
|
|
953
918
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
954
919
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
955
920
|
|
|
921
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
922
|
+
|
|
923
|
+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
924
|
+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
925
|
+
|
|
956
926
|
ggml_metal_kargs_sum_rows args = {
|
|
957
927
|
/*.ne00 =*/ ne00,
|
|
958
928
|
/*.ne01 =*/ ne01,
|
|
@@ -974,21 +944,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
|
|
974
944
|
|
|
975
945
|
auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
|
|
976
946
|
|
|
947
|
+
if (pipeline.c4) {
|
|
948
|
+
args.ne00 = ne00/4;
|
|
949
|
+
args.ne0 = ne0/4;
|
|
950
|
+
}
|
|
951
|
+
|
|
977
952
|
int nth = 32; // SIMD width
|
|
978
953
|
|
|
979
|
-
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
954
|
+
while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
980
955
|
nth *= 2;
|
|
981
956
|
}
|
|
982
957
|
|
|
983
958
|
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
984
|
-
nth = std::min(nth, ne00);
|
|
959
|
+
nth = std::min(nth, (int) args.ne00);
|
|
985
960
|
|
|
986
961
|
const size_t smem = pipeline.smem;
|
|
987
962
|
|
|
988
963
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
989
964
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
990
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
991
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
965
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
966
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
992
967
|
|
|
993
968
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
994
969
|
|
|
@@ -1247,6 +1222,48 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
|
|
|
1247
1222
|
return 1;
|
|
1248
1223
|
}
|
|
1249
1224
|
|
|
1225
|
+
int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) {
|
|
1226
|
+
ggml_tensor * op = ctx->node(idx);
|
|
1227
|
+
|
|
1228
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
1229
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
1230
|
+
|
|
1231
|
+
GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
|
|
1232
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1233
|
+
GGML_TENSOR_LOCALS(int32_t, ne, op, ne);
|
|
1234
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1235
|
+
|
|
1236
|
+
ggml_metal_kargs_diag args = {
|
|
1237
|
+
/*.ne00 =*/ne00,
|
|
1238
|
+
/*.ne01 =*/ne01,
|
|
1239
|
+
/*.ne02 =*/ne02,
|
|
1240
|
+
/*.ne03 =*/ne03,
|
|
1241
|
+
/*.nb00 =*/nb00,
|
|
1242
|
+
/*.nb01 =*/nb01,
|
|
1243
|
+
/*.nb02 =*/nb02,
|
|
1244
|
+
/*.nb03 =*/nb03,
|
|
1245
|
+
/*.ne0 =*/ne0,
|
|
1246
|
+
/*.ne1 =*/ne1,
|
|
1247
|
+
/*.ne2 =*/ne2,
|
|
1248
|
+
/*.ne3 =*/ne3,
|
|
1249
|
+
/*.nb0 =*/nb0,
|
|
1250
|
+
/*.nb1 =*/nb1,
|
|
1251
|
+
/*.nb2 =*/nb2,
|
|
1252
|
+
/*.nb3 =*/nb3,
|
|
1253
|
+
};
|
|
1254
|
+
|
|
1255
|
+
auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op);
|
|
1256
|
+
|
|
1257
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1258
|
+
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
|
1259
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
1260
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 2);
|
|
1261
|
+
|
|
1262
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1);
|
|
1263
|
+
|
|
1264
|
+
return 1;
|
|
1265
|
+
}
|
|
1266
|
+
|
|
1250
1267
|
int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
|
1251
1268
|
ggml_tensor * op = ctx->node(idx);
|
|
1252
1269
|
|
|
@@ -1508,7 +1525,180 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
|
|
1508
1525
|
return 1;
|
|
1509
1526
|
}
|
|
1510
1527
|
|
|
1511
|
-
int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
|
|
1528
|
+
int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
|
|
1529
|
+
ggml_tensor * op = ctx->node(idx);
|
|
1530
|
+
|
|
1531
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
1532
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
1533
|
+
|
|
1534
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1535
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1536
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1537
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1538
|
+
|
|
1539
|
+
const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
|
|
1540
|
+
const int64_t T = op->src[0]->ne[2];
|
|
1541
|
+
const int64_t C = op->ne[0];
|
|
1542
|
+
const int64_t H = op->src[0]->ne[1];
|
|
1543
|
+
|
|
1544
|
+
auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
|
|
1545
|
+
|
|
1546
|
+
int ida = 0;
|
|
1547
|
+
|
|
1548
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1549
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
|
|
1550
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
|
|
1551
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
|
|
1552
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
|
|
1553
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
|
|
1554
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++);
|
|
1555
|
+
if (op->op == GGML_OP_RWKV_WKV7) {
|
|
1556
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), ida++);
|
|
1557
|
+
}
|
|
1558
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++);
|
|
1559
|
+
ggml_metal_encoder_set_bytes (enc, (void *) &B, sizeof(B), ida++);
|
|
1560
|
+
ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++);
|
|
1561
|
+
ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++);
|
|
1562
|
+
ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++);
|
|
1563
|
+
|
|
1564
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1);
|
|
1565
|
+
|
|
1566
|
+
return 1;
|
|
1567
|
+
}
|
|
1568
|
+
|
|
1569
|
+
int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) {
|
|
1570
|
+
ggml_tensor * op = ctx->node(idx);
|
|
1571
|
+
|
|
1572
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
1573
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
1574
|
+
|
|
1575
|
+
|
|
1576
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1577
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1578
|
+
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
1579
|
+
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
1580
|
+
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
|
1581
|
+
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
|
1582
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1583
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1584
|
+
|
|
1585
|
+
auto pipeline = ggml_metal_library_get_pipeline_gated_delta_net(lib, op);
|
|
1586
|
+
|
|
1587
|
+
int ida = 0;
|
|
1588
|
+
|
|
1589
|
+
ggml_metal_kargs_gated_delta_net args = {
|
|
1590
|
+
/*.ne00 =*/ ne00,
|
|
1591
|
+
/*.ne01 =*/ ne01,
|
|
1592
|
+
/*.ne02 =*/ ne02,
|
|
1593
|
+
/*.ne03 =*/ ne03,
|
|
1594
|
+
/*.nb00 =*/ nb00,
|
|
1595
|
+
/*.nb01 =*/ nb01,
|
|
1596
|
+
/*.nb02 =*/ nb02,
|
|
1597
|
+
/*.nb03 =*/ nb03,
|
|
1598
|
+
/*.ne10 =*/ ne10,
|
|
1599
|
+
/*.ne11 =*/ ne11,
|
|
1600
|
+
/*.ne12 =*/ ne12,
|
|
1601
|
+
/*.ne13 =*/ ne13,
|
|
1602
|
+
/*.nb10 =*/ nb10,
|
|
1603
|
+
/*.nb11 =*/ nb11,
|
|
1604
|
+
/*.nb12 =*/ nb12,
|
|
1605
|
+
/*.nb13 =*/ nb13,
|
|
1606
|
+
/*.ne20 =*/ ne20,
|
|
1607
|
+
/*.ne21 =*/ ne21,
|
|
1608
|
+
/*.ne22 =*/ ne22,
|
|
1609
|
+
/*.ne23 =*/ ne23,
|
|
1610
|
+
/*.nb20 =*/ nb20,
|
|
1611
|
+
/*.nb21 =*/ nb21,
|
|
1612
|
+
/*.nb22 =*/ nb22,
|
|
1613
|
+
/*.nb23 =*/ nb23,
|
|
1614
|
+
/*.ns02 =*/ (int32_t) (nb02/sizeof(float)),
|
|
1615
|
+
/*.ns12 =*/ (int32_t) (nb12/sizeof(float)),
|
|
1616
|
+
/*.ns22 =*/ (int32_t) (nb22/sizeof(float)),
|
|
1617
|
+
/*.ne0 =*/ ne0,
|
|
1618
|
+
/*.ne1 =*/ ne1,
|
|
1619
|
+
/*.ne2 =*/ ne2,
|
|
1620
|
+
/*.ne3 =*/ ne3,
|
|
1621
|
+
/*.nb0 =*/ nb0,
|
|
1622
|
+
/*.nb1 =*/ nb1,
|
|
1623
|
+
/*.nb2 =*/ nb2,
|
|
1624
|
+
/*.nb3 =*/ nb3,
|
|
1625
|
+
};
|
|
1626
|
+
|
|
1627
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1628
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
|
|
1629
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); // q
|
|
1630
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); // k
|
|
1631
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); // v
|
|
1632
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); // gate
|
|
1633
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); // beta
|
|
1634
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); // state
|
|
1635
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); // dst
|
|
1636
|
+
|
|
1637
|
+
const int nsg = pipeline.nsg;
|
|
1638
|
+
|
|
1639
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, op->src[2]->ne[0]/nsg, op->src[2]->ne[1], op->src[2]->ne[3], 32, nsg, 1);
|
|
1640
|
+
|
|
1641
|
+
return 1;
|
|
1642
|
+
}
|
|
1643
|
+
|
|
1644
|
+
int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
|
|
1645
|
+
ggml_tensor * op = ctx->node(idx);
|
|
1646
|
+
|
|
1647
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
1648
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
1649
|
+
|
|
1650
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1651
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1652
|
+
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
1653
|
+
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
1654
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1655
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1656
|
+
|
|
1657
|
+
ggml_metal_kargs_solve_tri args = {
|
|
1658
|
+
/*.ne00 =*/ ne00,
|
|
1659
|
+
/*.ne01 =*/ ne01,
|
|
1660
|
+
/*.ne02 =*/ ne02,
|
|
1661
|
+
/*.ne03 =*/ ne03,
|
|
1662
|
+
/*.nb00 =*/ nb00,
|
|
1663
|
+
/*.nb01 =*/ nb01,
|
|
1664
|
+
/*.nb02 =*/ nb02,
|
|
1665
|
+
/*.nb03 =*/ nb03,
|
|
1666
|
+
/*.ne10 =*/ ne10,
|
|
1667
|
+
/*.ne11 =*/ ne11,
|
|
1668
|
+
/*.ne12 =*/ ne12,
|
|
1669
|
+
/*.ne13 =*/ ne13,
|
|
1670
|
+
/*.nb10 =*/ nb10,
|
|
1671
|
+
/*.nb11 =*/ nb11,
|
|
1672
|
+
/*.nb12 =*/ nb12,
|
|
1673
|
+
/*.nb13 =*/ nb13,
|
|
1674
|
+
/*.ne0 =*/ ne0,
|
|
1675
|
+
/*.ne1 =*/ ne1,
|
|
1676
|
+
/*.ne2 =*/ ne2,
|
|
1677
|
+
/*.ne3 =*/ ne3,
|
|
1678
|
+
/*.nb0 =*/ nb0,
|
|
1679
|
+
/*.nb1 =*/ nb1,
|
|
1680
|
+
/*.nb2 =*/ nb2,
|
|
1681
|
+
/*.nb3 =*/ nb3,
|
|
1682
|
+
};
|
|
1683
|
+
|
|
1684
|
+
auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
|
|
1685
|
+
|
|
1686
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1687
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
1688
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
1689
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
1690
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
|
1691
|
+
|
|
1692
|
+
const int nsg = pipeline.nsg;
|
|
1693
|
+
|
|
1694
|
+
ggml_metal_encoder_set_threadgroup_memory_size(enc, pipeline.smem, 0);
|
|
1695
|
+
|
|
1696
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, (ne10 + nsg - 1)/nsg, ne02, ne03, 32, nsg, 1);
|
|
1697
|
+
|
|
1698
|
+
return 1;
|
|
1699
|
+
}
|
|
1700
|
+
|
|
1701
|
+
int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
|
|
1512
1702
|
ggml_tensor * op = ctx->node(idx);
|
|
1513
1703
|
|
|
1514
1704
|
ggml_metal_library_t lib = ctx->lib;
|
|
@@ -1516,35 +1706,122 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
|
|
|
1516
1706
|
|
|
1517
1707
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1518
1708
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1709
|
+
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
1710
|
+
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
1519
1711
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1520
1712
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1521
1713
|
|
|
1522
|
-
|
|
1523
|
-
|
|
1524
|
-
|
|
1525
|
-
const int64_t H = op->src[0]->ne[1];
|
|
1714
|
+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
1715
|
+
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
|
|
1716
|
+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
1526
1717
|
|
|
1527
|
-
|
|
1718
|
+
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
|
|
1719
|
+
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
|
|
1720
|
+
const size_t pnb3 = ((const int32_t *) op->op_params)[2];
|
|
1721
|
+
const size_t offs = ((const int32_t *) op->op_params)[3];
|
|
1528
1722
|
|
|
1529
|
-
|
|
1723
|
+
const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
|
|
1530
1724
|
|
|
1531
|
-
|
|
1532
|
-
|
|
1533
|
-
|
|
1534
|
-
|
|
1535
|
-
|
|
1536
|
-
|
|
1537
|
-
|
|
1538
|
-
|
|
1539
|
-
|
|
1725
|
+
if (!inplace) {
|
|
1726
|
+
// run a separate kernel to cpy src->dst
|
|
1727
|
+
// not sure how to avoid this
|
|
1728
|
+
// TODO: make a simpler cpy_bytes kernel
|
|
1729
|
+
|
|
1730
|
+
//const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
|
|
1731
|
+
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
|
1732
|
+
|
|
1733
|
+
ggml_metal_kargs_cpy args = {
|
|
1734
|
+
/*.nk0 =*/ ne00,
|
|
1735
|
+
/*.ne00 =*/ ne00,
|
|
1736
|
+
/*.ne01 =*/ ne01,
|
|
1737
|
+
/*.ne02 =*/ ne02,
|
|
1738
|
+
/*.ne03 =*/ ne03,
|
|
1739
|
+
/*.nb00 =*/ nb00,
|
|
1740
|
+
/*.nb01 =*/ nb01,
|
|
1741
|
+
/*.nb02 =*/ nb02,
|
|
1742
|
+
/*.nb03 =*/ nb03,
|
|
1743
|
+
/*.ne0 =*/ ne0,
|
|
1744
|
+
/*.ne1 =*/ ne1,
|
|
1745
|
+
/*.ne2 =*/ ne2,
|
|
1746
|
+
/*.ne3 =*/ ne3,
|
|
1747
|
+
/*.nb0 =*/ nb0,
|
|
1748
|
+
/*.nb1 =*/ nb1,
|
|
1749
|
+
/*.nb2 =*/ nb2,
|
|
1750
|
+
/*.nb3 =*/ nb3,
|
|
1751
|
+
};
|
|
1752
|
+
|
|
1753
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1754
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
1755
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
1756
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
1757
|
+
|
|
1758
|
+
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
|
|
1759
|
+
|
|
1760
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
|
1761
|
+
|
|
1762
|
+
ggml_metal_op_concurrency_reset(ctx);
|
|
1540
1763
|
}
|
|
1541
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++);
|
|
1542
|
-
ggml_metal_encoder_set_bytes (enc, (void *) &B, sizeof(B), ida++);
|
|
1543
|
-
ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++);
|
|
1544
|
-
ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++);
|
|
1545
|
-
ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++);
|
|
1546
1764
|
|
|
1547
|
-
|
|
1765
|
+
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);
|
|
1766
|
+
|
|
1767
|
+
GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);
|
|
1768
|
+
|
|
1769
|
+
int64_t nk0 = ne10;
|
|
1770
|
+
if (ggml_is_quantized(op->src[1]->type)) {
|
|
1771
|
+
nk0 = ne10/16;
|
|
1772
|
+
} else if (ggml_is_quantized(op->type)) {
|
|
1773
|
+
nk0 = ne10/ggml_blck_size(op->type);
|
|
1774
|
+
}
|
|
1775
|
+
|
|
1776
|
+
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
1777
|
+
|
|
1778
|
+
// when rows are small, we can batch them together in a single threadgroup
|
|
1779
|
+
int nrptg = 1;
|
|
1780
|
+
|
|
1781
|
+
// TODO: relax this constraint in the future
|
|
1782
|
+
if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {
|
|
1783
|
+
if (nth > nk0) {
|
|
1784
|
+
nrptg = (nth + nk0 - 1)/nk0;
|
|
1785
|
+
nth = nk0;
|
|
1786
|
+
|
|
1787
|
+
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
1788
|
+
nrptg--;
|
|
1789
|
+
}
|
|
1790
|
+
}
|
|
1791
|
+
}
|
|
1792
|
+
|
|
1793
|
+
nth = std::min<int>(nth, nk0);
|
|
1794
|
+
|
|
1795
|
+
ggml_metal_kargs_cpy args = {
|
|
1796
|
+
/*.nk0 =*/ nk0,
|
|
1797
|
+
/*.ne00 =*/ ne10,
|
|
1798
|
+
/*.ne01 =*/ ne11,
|
|
1799
|
+
/*.ne02 =*/ ne12,
|
|
1800
|
+
/*.ne03 =*/ ne13,
|
|
1801
|
+
/*.nb00 =*/ nb10,
|
|
1802
|
+
/*.nb01 =*/ nb11,
|
|
1803
|
+
/*.nb02 =*/ nb12,
|
|
1804
|
+
/*.nb03 =*/ nb13,
|
|
1805
|
+
/*.ne0 =*/ ne10,
|
|
1806
|
+
/*.ne1 =*/ ne11,
|
|
1807
|
+
/*.ne2 =*/ ne12,
|
|
1808
|
+
/*.ne3 =*/ ne13,
|
|
1809
|
+
/*.nb0 =*/ ggml_element_size(op),
|
|
1810
|
+
/*.nb1 =*/ pnb1,
|
|
1811
|
+
/*.nb2 =*/ pnb2,
|
|
1812
|
+
/*.nb3 =*/ pnb3,
|
|
1813
|
+
};
|
|
1814
|
+
|
|
1815
|
+
const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
|
|
1816
|
+
|
|
1817
|
+
bid_dst.offs += offs;
|
|
1818
|
+
|
|
1819
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1820
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
1821
|
+
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
|
|
1822
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
1823
|
+
|
|
1824
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);
|
|
1548
1825
|
|
|
1549
1826
|
return 1;
|
|
1550
1827
|
}
|
|
@@ -1622,6 +1899,54 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
|
|
1622
1899
|
return 1;
|
|
1623
1900
|
}
|
|
1624
1901
|
|
|
1902
|
+
int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) {
|
|
1903
|
+
ggml_tensor * op = ctx->node(idx);
|
|
1904
|
+
|
|
1905
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
1906
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
1907
|
+
|
|
1908
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1909
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1910
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1911
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1912
|
+
|
|
1913
|
+
const int32_t * opts = op->op_params;
|
|
1914
|
+
ggml_op_pool op_pool = (ggml_op_pool) opts[0];
|
|
1915
|
+
|
|
1916
|
+
const int32_t k0 = opts[1];
|
|
1917
|
+
const int32_t s0 = opts[2];
|
|
1918
|
+
const int32_t p0 = opts[3];
|
|
1919
|
+
|
|
1920
|
+
const int64_t IW = op->src[0]->ne[0];
|
|
1921
|
+
const int64_t OW = op->ne[0];
|
|
1922
|
+
|
|
1923
|
+
const int64_t np = ggml_nelements(op);
|
|
1924
|
+
|
|
1925
|
+
ggml_metal_kargs_pool_1d args_pool_1d = {
|
|
1926
|
+
/* .k0 = */ k0,
|
|
1927
|
+
/* .s0 = */ s0,
|
|
1928
|
+
/* .p0 = */ p0,
|
|
1929
|
+
/* .IW = */ IW,
|
|
1930
|
+
/* .OW = */ OW,
|
|
1931
|
+
/* .np = */ np
|
|
1932
|
+
};
|
|
1933
|
+
|
|
1934
|
+
auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool);
|
|
1935
|
+
|
|
1936
|
+
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
|
|
1937
|
+
const int ntg = (np + nth - 1) / nth;
|
|
1938
|
+
|
|
1939
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1940
|
+
ggml_metal_encoder_set_bytes (enc, &args_pool_1d, sizeof(args_pool_1d), 0);
|
|
1941
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
1942
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
1943
|
+
|
|
1944
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
|
|
1945
|
+
|
|
1946
|
+
return 1;
|
|
1947
|
+
}
|
|
1948
|
+
|
|
1949
|
+
|
|
1625
1950
|
int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
|
|
1626
1951
|
ggml_tensor * op = ctx->node(idx);
|
|
1627
1952
|
|
|
@@ -1717,6 +2042,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|
|
1717
2042
|
(
|
|
1718
2043
|
op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
|
|
1719
2044
|
op->src[0]->type == GGML_TYPE_F16 ||
|
|
2045
|
+
op->src[0]->type == GGML_TYPE_BF16 ||
|
|
1720
2046
|
op->src[0]->type == GGML_TYPE_Q4_0 ||
|
|
1721
2047
|
op->src[0]->type == GGML_TYPE_Q4_1 ||
|
|
1722
2048
|
op->src[0]->type == GGML_TYPE_Q5_0 ||
|
|
@@ -1731,6 +2057,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|
|
1731
2057
|
op->src[0]->type == GGML_TYPE_Q4_K ||
|
|
1732
2058
|
op->src[0]->type == GGML_TYPE_Q5_K ||
|
|
1733
2059
|
op->src[0]->type == GGML_TYPE_Q6_K ||
|
|
2060
|
+
op->src[0]->type == GGML_TYPE_Q2_K ||
|
|
2061
|
+
op->src[0]->type == GGML_TYPE_Q3_K ||
|
|
1734
2062
|
false) && (ne11 >= 4 && ne11 <= 8)
|
|
1735
2063
|
)
|
|
1736
2064
|
)
|
|
@@ -1759,7 +2087,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|
|
1759
2087
|
const int16_t r0ptg = nypsg*nsg; // num src0 rows per threadgroup
|
|
1760
2088
|
int16_t r1ptg = 4; // num src1 rows per threadgroup
|
|
1761
2089
|
|
|
1762
|
-
// note: not sure how optimal are those across all different hardware. there might be
|
|
2090
|
+
// note: not sure how optimal are those across all different hardware. there might be something cleverer
|
|
1763
2091
|
switch (ne11) {
|
|
1764
2092
|
case 2:
|
|
1765
2093
|
r1ptg = 2; break;
|
|
@@ -2239,7 +2567,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
|
|
|
2239
2567
|
// return res;
|
|
2240
2568
|
//}
|
|
2241
2569
|
|
|
2242
|
-
const int nqptg = is_vec ?
|
|
2570
|
+
const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG;
|
|
2243
2571
|
const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
|
|
2244
2572
|
|
|
2245
2573
|
const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
|
|
@@ -2355,7 +2683,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2355
2683
|
|
|
2356
2684
|
if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
|
2357
2685
|
// half8x8 kernel
|
|
2358
|
-
const int nqptg =
|
|
2686
|
+
const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup
|
|
2359
2687
|
const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
|
|
2360
2688
|
|
|
2361
2689
|
GGML_ASSERT(nqptg <= 32);
|
|
@@ -2464,7 +2792,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2464
2792
|
|
|
2465
2793
|
// simdgroups per threadgroup (a.k.a. warps)
|
|
2466
2794
|
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
|
2467
|
-
int32_t nsg = 4;
|
|
2795
|
+
int32_t nsg = ne00 >= 512 ? 8 : 4;
|
|
2468
2796
|
|
|
2469
2797
|
const size_t smem = FATTN_SMEM(nsg);
|
|
2470
2798
|
|
|
@@ -2522,9 +2850,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2522
2850
|
#undef FATTN_SMEM
|
|
2523
2851
|
} else {
|
|
2524
2852
|
// half4x4 kernel
|
|
2525
|
-
const int nqptg =
|
|
2853
|
+
const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup
|
|
2526
2854
|
const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
|
|
2527
|
-
const int
|
|
2855
|
+
const int nhptg = 1; // heads per threadgroup
|
|
2528
2856
|
|
|
2529
2857
|
GGML_ASSERT(nqptg <= 32);
|
|
2530
2858
|
GGML_ASSERT(nqptg % 1 == 0);
|
|
@@ -2576,6 +2904,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2576
2904
|
ggml_metal_op_concurrency_reset(ctx);
|
|
2577
2905
|
}
|
|
2578
2906
|
|
|
2907
|
+
// note: for simplicity assume the K is larger or equal than V
|
|
2908
|
+
GGML_ASSERT(ne10 >= ne20);
|
|
2909
|
+
|
|
2579
2910
|
// ne00 + 2*ncpsg*(nsg)
|
|
2580
2911
|
// for each query, we load it as f16 in shared memory (ne00)
|
|
2581
2912
|
// and store the soft_max values and the mask
|
|
@@ -2583,28 +2914,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2583
2914
|
// ne20*(nsg)
|
|
2584
2915
|
// each simdgroup has a full f32 head vector in shared mem to accumulate results
|
|
2585
2916
|
//
|
|
2586
|
-
#define FATTN_SMEM(nsg) (GGML_PAD((
|
|
2587
|
-
|
|
2588
|
-
int64_t nsgmax = 2;
|
|
2589
|
-
while (true) {
|
|
2590
|
-
const size_t smem = FATTN_SMEM(nsgmax);
|
|
2591
|
-
// avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
|
|
2592
|
-
if (smem > props_dev->max_theadgroup_memory_size/2) {
|
|
2593
|
-
break;
|
|
2594
|
-
}
|
|
2595
|
-
nsgmax *= 2;
|
|
2596
|
-
}
|
|
2597
|
-
nsgmax /= 2;
|
|
2598
|
-
|
|
2599
|
-
// simdgroups per threadgroup (a.k.a. warps)
|
|
2600
|
-
//const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
|
|
2601
|
-
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32)));
|
|
2917
|
+
#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16))
|
|
2602
2918
|
|
|
2603
2919
|
int64_t nsg = 1;
|
|
2604
|
-
while (nsg <= nsgt) {
|
|
2605
|
-
nsg *= 2;
|
|
2606
|
-
}
|
|
2607
|
-
nsg /= 2;
|
|
2608
2920
|
|
|
2609
2921
|
// workgroups
|
|
2610
2922
|
// each workgroup handles nsg*nkpsg cache values
|
|
@@ -2617,7 +2929,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2617
2929
|
} else {
|
|
2618
2930
|
nwg = 32;
|
|
2619
2931
|
nsg = 1;
|
|
2620
|
-
while (2*nwg*nsg*
|
|
2932
|
+
while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) {
|
|
2621
2933
|
nsg *= 2;
|
|
2622
2934
|
}
|
|
2623
2935
|
}
|
|
@@ -2683,7 +2995,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2683
2995
|
|
|
2684
2996
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
2685
2997
|
|
|
2686
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
|
|
2998
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
|
|
2687
2999
|
} else {
|
|
2688
3000
|
// sanity checks
|
|
2689
3001
|
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
|
|
@@ -2696,7 +3008,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2696
3008
|
ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
|
|
2697
3009
|
|
|
2698
3010
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
2699
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
|
|
3011
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
|
|
2700
3012
|
|
|
2701
3013
|
// sync the 2 kernels
|
|
2702
3014
|
ggml_metal_op_concurrency_reset(ctx);
|
|
@@ -2748,8 +3060,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|
|
2748
3060
|
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
2749
3061
|
GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
|
|
2750
3062
|
|
|
2751
|
-
bool bcast_row = false;
|
|
2752
|
-
|
|
2753
3063
|
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
2754
3064
|
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
|
|
2755
3065
|
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
@@ -2843,18 +3153,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|
|
2843
3153
|
|
|
2844
3154
|
struct ggml_metal_pipeline_with_params pipeline;
|
|
2845
3155
|
|
|
2846
|
-
|
|
2847
|
-
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
2848
|
-
|
|
2849
|
-
// src1 is a row
|
|
2850
|
-
GGML_ASSERT(ne11 == 1);
|
|
2851
|
-
|
|
2852
|
-
pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true);
|
|
2853
|
-
|
|
2854
|
-
bcast_row = true;
|
|
2855
|
-
} else {
|
|
2856
|
-
pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false);
|
|
2857
|
-
}
|
|
3156
|
+
pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse);
|
|
2858
3157
|
|
|
2859
3158
|
if (n_fuse > 1) {
|
|
2860
3159
|
bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
|
|
@@ -2868,20 +3167,26 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|
|
2868
3167
|
}
|
|
2869
3168
|
}
|
|
2870
3169
|
|
|
3170
|
+
if (pipeline.c4) {
|
|
3171
|
+
args.ne00 = ne00/4;
|
|
3172
|
+
args.ne10 = ne10/4;
|
|
3173
|
+
args.ne0 = ne0/4;
|
|
3174
|
+
}
|
|
3175
|
+
|
|
2871
3176
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2872
3177
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
2873
3178
|
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
2874
3179
|
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
|
|
2875
3180
|
ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
|
|
2876
3181
|
|
|
2877
|
-
if (
|
|
2878
|
-
|
|
2879
|
-
|
|
2880
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
|
3182
|
+
if (pipeline.cnt) {
|
|
3183
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, args.ne0, ggml_nrows(op), 1, 1, 1, 1);
|
|
2881
3184
|
} else {
|
|
2882
|
-
int
|
|
3185
|
+
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
3186
|
+
|
|
3187
|
+
int nth = 1;
|
|
2883
3188
|
|
|
2884
|
-
while (
|
|
3189
|
+
while (2*nth < args.ne0 && nth < nth_max) {
|
|
2885
3190
|
nth *= 2;
|
|
2886
3191
|
}
|
|
2887
3192
|
|
|
@@ -2902,39 +3207,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
2902
3207
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2903
3208
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
2904
3209
|
|
|
3210
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
3211
|
+
|
|
3212
|
+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
3213
|
+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
3214
|
+
|
|
2905
3215
|
float eps;
|
|
2906
3216
|
memcpy(&eps, op->op_params, sizeof(float));
|
|
2907
3217
|
|
|
2908
|
-
int nth = 32; // SIMD width
|
|
2909
|
-
|
|
2910
3218
|
ggml_metal_kargs_l2_norm args = {
|
|
2911
|
-
/*.ne00
|
|
2912
|
-
/*.
|
|
2913
|
-
/*.
|
|
2914
|
-
/*.
|
|
3219
|
+
/*.ne00 =*/ ne00,
|
|
3220
|
+
/*.ne01 =*/ ne01,
|
|
3221
|
+
/*.ne02 =*/ ne02,
|
|
3222
|
+
/*.ne03 =*/ ne03,
|
|
3223
|
+
/*.nb00 =*/ nb00,
|
|
3224
|
+
/*.nb01 =*/ nb01,
|
|
3225
|
+
/*.nb02 =*/ nb02,
|
|
3226
|
+
/*.nb03 =*/ nb03,
|
|
3227
|
+
/*.ne0 =*/ ne0,
|
|
3228
|
+
/*.ne1 =*/ ne1,
|
|
3229
|
+
/*.ne2 =*/ ne2,
|
|
3230
|
+
/*.ne3 =*/ ne3,
|
|
3231
|
+
/*.nb0 =*/ nb0,
|
|
3232
|
+
/*.nb1 =*/ nb1,
|
|
3233
|
+
/*.nb2 =*/ nb2,
|
|
3234
|
+
/*.nb3 =*/ nb3,
|
|
3235
|
+
/*.eps =*/ eps,
|
|
2915
3236
|
};
|
|
2916
3237
|
|
|
2917
3238
|
auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
|
|
2918
3239
|
|
|
2919
|
-
|
|
3240
|
+
if (pipeline.c4) {
|
|
3241
|
+
args.ne00 = ne00/4;
|
|
3242
|
+
args.ne0 = ne0/4;
|
|
3243
|
+
}
|
|
3244
|
+
|
|
3245
|
+
int nth = 32; // SIMD width
|
|
3246
|
+
|
|
3247
|
+
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
2920
3248
|
nth *= 2;
|
|
2921
3249
|
}
|
|
2922
3250
|
|
|
2923
3251
|
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
2924
|
-
nth = std::min(nth, ne00/4);
|
|
2925
3252
|
|
|
2926
3253
|
const size_t smem = pipeline.smem;
|
|
2927
3254
|
|
|
2928
|
-
const int64_t nrows = ggml_nrows(op->src[0]);
|
|
2929
|
-
|
|
2930
3255
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2931
3256
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
2932
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
2933
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
3257
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
3258
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
2934
3259
|
|
|
2935
3260
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
2936
3261
|
|
|
2937
|
-
ggml_metal_encoder_dispatch_threadgroups(enc,
|
|
3262
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
|
2938
3263
|
|
|
2939
3264
|
return 1;
|
|
2940
3265
|
}
|
|
@@ -3484,32 +3809,43 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
|
|
|
3484
3809
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3485
3810
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3486
3811
|
|
|
3487
|
-
|
|
3488
|
-
|
|
3489
|
-
|
|
3490
|
-
|
|
3812
|
+
float sf0 = (float)ne0/op->src[0]->ne[0];
|
|
3813
|
+
float sf1 = (float)ne1/op->src[0]->ne[1];
|
|
3814
|
+
float sf2 = (float)ne2/op->src[0]->ne[2];
|
|
3815
|
+
float sf3 = (float)ne3/op->src[0]->ne[3];
|
|
3816
|
+
|
|
3817
|
+
const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
|
|
3818
|
+
|
|
3819
|
+
float poffs = 0.5f;
|
|
3820
|
+
|
|
3821
|
+
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
|
3822
|
+
poffs = 0.0f;
|
|
3823
|
+
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
|
|
3824
|
+
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
|
|
3825
|
+
}
|
|
3491
3826
|
|
|
3492
3827
|
ggml_metal_kargs_upscale args = {
|
|
3493
|
-
/*.ne00
|
|
3494
|
-
/*.ne01
|
|
3495
|
-
/*.ne02
|
|
3496
|
-
/*.ne03
|
|
3497
|
-
/*.nb00
|
|
3498
|
-
/*.nb01
|
|
3499
|
-
/*.nb02
|
|
3500
|
-
/*.nb03
|
|
3501
|
-
/*.ne0
|
|
3502
|
-
/*.ne1
|
|
3503
|
-
/*.ne2
|
|
3504
|
-
/*.ne3
|
|
3505
|
-
/*.nb0
|
|
3506
|
-
/*.nb1
|
|
3507
|
-
/*.nb2
|
|
3508
|
-
/*.nb3
|
|
3509
|
-
/*.sf0
|
|
3510
|
-
/*.sf1
|
|
3511
|
-
/*.sf2
|
|
3512
|
-
/*.sf3
|
|
3828
|
+
/*.ne00 =*/ ne00,
|
|
3829
|
+
/*.ne01 =*/ ne01,
|
|
3830
|
+
/*.ne02 =*/ ne02,
|
|
3831
|
+
/*.ne03 =*/ ne03,
|
|
3832
|
+
/*.nb00 =*/ nb00,
|
|
3833
|
+
/*.nb01 =*/ nb01,
|
|
3834
|
+
/*.nb02 =*/ nb02,
|
|
3835
|
+
/*.nb03 =*/ nb03,
|
|
3836
|
+
/*.ne0 =*/ ne0,
|
|
3837
|
+
/*.ne1 =*/ ne1,
|
|
3838
|
+
/*.ne2 =*/ ne2,
|
|
3839
|
+
/*.ne3 =*/ ne3,
|
|
3840
|
+
/*.nb0 =*/ nb0,
|
|
3841
|
+
/*.nb1 =*/ nb1,
|
|
3842
|
+
/*.nb2 =*/ nb2,
|
|
3843
|
+
/*.nb3 =*/ nb3,
|
|
3844
|
+
/*.sf0 =*/ sf0,
|
|
3845
|
+
/*.sf1 =*/ sf1,
|
|
3846
|
+
/*.sf2 =*/ sf2,
|
|
3847
|
+
/*.sf3 =*/ sf3,
|
|
3848
|
+
/*.poffs =*/ poffs,
|
|
3513
3849
|
};
|
|
3514
3850
|
|
|
3515
3851
|
auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
|
|
@@ -3942,42 +4278,6 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
|
|
|
3942
4278
|
return 1;
|
|
3943
4279
|
}
|
|
3944
4280
|
|
|
3945
|
-
int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
|
3946
|
-
ggml_tensor * op = ctx->node(idx);
|
|
3947
|
-
|
|
3948
|
-
ggml_metal_library_t lib = ctx->lib;
|
|
3949
|
-
ggml_metal_encoder_t enc = ctx->enc;
|
|
3950
|
-
|
|
3951
|
-
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
3952
|
-
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
3953
|
-
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3954
|
-
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3955
|
-
|
|
3956
|
-
float slope;
|
|
3957
|
-
memcpy(&slope, op->op_params, sizeof(float));
|
|
3958
|
-
|
|
3959
|
-
ggml_metal_kargs_leaky_relu args = {
|
|
3960
|
-
/*.slope =*/ slope
|
|
3961
|
-
};
|
|
3962
|
-
|
|
3963
|
-
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
|
3964
|
-
|
|
3965
|
-
int64_t n = ggml_nelements(op);
|
|
3966
|
-
|
|
3967
|
-
if (n % 4 == 0) {
|
|
3968
|
-
n /= 4;
|
|
3969
|
-
}
|
|
3970
|
-
|
|
3971
|
-
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3972
|
-
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
3973
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
3974
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
3975
|
-
|
|
3976
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
|
3977
|
-
|
|
3978
|
-
return 1;
|
|
3979
|
-
}
|
|
3980
|
-
|
|
3981
4281
|
int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
|
|
3982
4282
|
ggml_tensor * op = ctx->node(idx);
|
|
3983
4283
|
|