whispercpp 1.3.4 → 1.3.5
This diff compares two publicly available versions of this package, both released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +47 -23
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/cli/cli.cpp +121 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +10 -11
- data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
- data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
- data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
- data/ext/sources/examples/talk-llama/llama-context.h +57 -9
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
- data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
- data/ext/sources/examples/talk-llama/llama-model.h +44 -3
- data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
- data/ext/sources/examples/talk-llama/llama.cpp +729 -2
- data/ext/sources/examples/talk-llama/llama.h +152 -14
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +82 -54
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +4 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +190 -12
- data/ext/sources/ggml/src/CMakeLists.txt +82 -11
- data/ext/sources/ggml/src/ggml-alloc.c +124 -41
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
- data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
- data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
- data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
- data/ext/sources/ggml/src/ggml-impl.h +67 -6
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +425 -33
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +101 -35
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +119 -2
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +7 -0
- data/whispercpp.gemspec +1 -1
- metadata +287 -34
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -50,14 +50,14 @@ void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, gg
|
|
|
50
50
|
}
|
|
51
51
|
|
|
52
52
|
ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) {
|
|
53
|
-
if
|
|
53
|
+
if (ppls->data.find(name) == ppls->data.end()) {
|
|
54
54
|
return nullptr;
|
|
55
55
|
}
|
|
56
56
|
|
|
57
57
|
return ppls->data[name];
|
|
58
58
|
}
|
|
59
59
|
|
|
60
|
-
|
|
60
|
+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) {
|
|
61
61
|
char base[256];
|
|
62
62
|
char name[256];
|
|
63
63
|
|
|
@@ -71,34 +71,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base(ggml_metal_library_t
|
|
|
71
71
|
snprintf(base, 256, "kernel_%s", op_str);
|
|
72
72
|
snprintf(name, 256, "%s", base);
|
|
73
73
|
|
|
74
|
-
|
|
75
|
-
if (res) {
|
|
76
|
-
|
|
74
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
75
|
+
if (!res.pipeline) {
|
|
76
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
77
77
|
}
|
|
78
78
|
|
|
79
|
-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
80
|
-
|
|
81
79
|
return res;
|
|
82
80
|
}
|
|
83
81
|
|
|
84
|
-
|
|
82
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) {
|
|
85
83
|
char base[256];
|
|
86
84
|
char name[256];
|
|
87
85
|
|
|
88
86
|
snprintf(base, 256, "kernel_cpy_%s_%s", ggml_type_name(tsrc), ggml_type_name(tdst));
|
|
89
87
|
snprintf(name, 256, "%s", base);
|
|
90
88
|
|
|
91
|
-
|
|
92
|
-
if (res) {
|
|
93
|
-
|
|
89
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
90
|
+
if (!res.pipeline) {
|
|
91
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
94
92
|
}
|
|
95
93
|
|
|
96
|
-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
97
|
-
|
|
98
94
|
return res;
|
|
99
95
|
}
|
|
100
96
|
|
|
101
|
-
|
|
97
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
|
|
102
98
|
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
103
99
|
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
|
|
104
100
|
|
|
@@ -115,68 +111,60 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library
|
|
|
115
111
|
snprintf(base, 256, "kernel_pool_2d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
|
|
116
112
|
snprintf(name, 256, "%s", base);
|
|
117
113
|
|
|
118
|
-
|
|
119
|
-
if (res) {
|
|
120
|
-
|
|
114
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
115
|
+
if (!res.pipeline) {
|
|
116
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
121
117
|
}
|
|
122
118
|
|
|
123
|
-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
124
|
-
|
|
125
119
|
return res;
|
|
126
120
|
}
|
|
127
121
|
|
|
128
|
-
|
|
122
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) {
|
|
129
123
|
char base[256];
|
|
130
124
|
char name[256];
|
|
131
125
|
|
|
132
126
|
snprintf(base, 256, "kernel_get_rows_%s", ggml_type_name(tsrc));
|
|
133
127
|
snprintf(name, 256, "%s", base);
|
|
134
128
|
|
|
135
|
-
|
|
136
|
-
if (res) {
|
|
137
|
-
|
|
129
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
130
|
+
if (!res.pipeline) {
|
|
131
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
138
132
|
}
|
|
139
133
|
|
|
140
|
-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
141
|
-
|
|
142
134
|
return res;
|
|
143
135
|
}
|
|
144
136
|
|
|
145
|
-
|
|
137
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) {
|
|
146
138
|
char base[256];
|
|
147
139
|
char name[256];
|
|
148
140
|
|
|
149
141
|
snprintf(base, 256, "kernel_set_rows_%s_%s", ggml_type_name(tdst), ggml_type_name(tidx));
|
|
150
142
|
snprintf(name, 256, "%s", base);
|
|
151
143
|
|
|
152
|
-
|
|
153
|
-
if (res) {
|
|
154
|
-
|
|
144
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
145
|
+
if (!res.pipeline) {
|
|
146
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
155
147
|
}
|
|
156
148
|
|
|
157
|
-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
158
|
-
|
|
159
149
|
return res;
|
|
160
150
|
}
|
|
161
151
|
|
|
162
|
-
|
|
152
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) {
|
|
163
153
|
char base[256];
|
|
164
154
|
char name[256];
|
|
165
155
|
|
|
166
156
|
snprintf(base, 256, "kernel_repeat_%s", ggml_type_name(tsrc));
|
|
167
157
|
snprintf(name, 256, "%s", base);
|
|
168
158
|
|
|
169
|
-
|
|
170
|
-
if (res) {
|
|
171
|
-
|
|
159
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
160
|
+
if (!res.pipeline) {
|
|
161
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
172
162
|
}
|
|
173
163
|
|
|
174
|
-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
175
|
-
|
|
176
164
|
return res;
|
|
177
165
|
}
|
|
178
166
|
|
|
179
|
-
|
|
167
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
180
168
|
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
181
169
|
|
|
182
170
|
char base[256];
|
|
@@ -187,6 +175,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t
|
|
|
187
175
|
const char * op_str = "undefined";
|
|
188
176
|
switch (op->op) {
|
|
189
177
|
case GGML_OP_SCALE: op_str = "scale"; break;
|
|
178
|
+
case GGML_OP_FILL: op_str = "fill"; break;
|
|
190
179
|
case GGML_OP_CLAMP: op_str = "clamp"; break;
|
|
191
180
|
case GGML_OP_SQR: op_str = "sqr"; break;
|
|
192
181
|
case GGML_OP_SQRT: op_str = "sqrt"; break;
|
|
@@ -211,6 +200,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t
|
|
|
211
200
|
case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break;
|
|
212
201
|
case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
|
|
213
202
|
case GGML_UNARY_OP_EXP: op_str = "exp"; break;
|
|
203
|
+
case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break;
|
|
204
|
+
case GGML_UNARY_OP_EXPM1: op_str = "expm1"; break;
|
|
214
205
|
default: GGML_ABORT("fatal error");
|
|
215
206
|
} break;
|
|
216
207
|
default: GGML_ABORT("fatal error");
|
|
@@ -224,17 +215,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t
|
|
|
224
215
|
snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix);
|
|
225
216
|
snprintf(name, 256, "%s", base);
|
|
226
217
|
|
|
227
|
-
|
|
228
|
-
if (res) {
|
|
229
|
-
|
|
218
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
219
|
+
if (!res.pipeline) {
|
|
220
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
230
221
|
}
|
|
231
222
|
|
|
232
|
-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
233
|
-
|
|
234
223
|
return res;
|
|
235
224
|
}
|
|
236
225
|
|
|
237
|
-
|
|
226
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
238
227
|
GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
|
|
239
228
|
|
|
240
229
|
char base[256];
|
|
@@ -258,17 +247,32 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l
|
|
|
258
247
|
snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
|
|
259
248
|
snprintf(name, 256, "%s", base);
|
|
260
249
|
|
|
261
|
-
|
|
262
|
-
if (res) {
|
|
263
|
-
|
|
250
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
251
|
+
if (!res.pipeline) {
|
|
252
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
264
253
|
}
|
|
265
254
|
|
|
266
|
-
res
|
|
255
|
+
return res;
|
|
256
|
+
}
|
|
257
|
+
|
|
258
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
259
|
+
assert(op->op == GGML_OP_SUM);
|
|
260
|
+
|
|
261
|
+
char base[256];
|
|
262
|
+
char name[256];
|
|
263
|
+
|
|
264
|
+
snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type));
|
|
265
|
+
snprintf(name, 256, "%s", base);
|
|
266
|
+
|
|
267
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
268
|
+
if (!res.pipeline) {
|
|
269
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
270
|
+
}
|
|
267
271
|
|
|
268
272
|
return res;
|
|
269
273
|
}
|
|
270
274
|
|
|
271
|
-
|
|
275
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
272
276
|
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
|
|
273
277
|
|
|
274
278
|
char base[256];
|
|
@@ -287,19 +291,73 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_librar
|
|
|
287
291
|
|
|
288
292
|
snprintf(name, 256, "%s", base);
|
|
289
293
|
|
|
290
|
-
|
|
291
|
-
if (res) {
|
|
292
|
-
|
|
294
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
295
|
+
if (!res.pipeline) {
|
|
296
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
293
297
|
}
|
|
294
298
|
|
|
295
|
-
res =
|
|
299
|
+
res.smem = 32*sizeof(float);
|
|
296
300
|
|
|
297
|
-
|
|
301
|
+
return res;
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
305
|
+
GGML_ASSERT(op->op == GGML_OP_CUMSUM);
|
|
306
|
+
|
|
307
|
+
char base[256];
|
|
308
|
+
char name[256];
|
|
309
|
+
|
|
310
|
+
snprintf(base, 256, "kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type));
|
|
311
|
+
snprintf(name, 256, "%s", base);
|
|
312
|
+
|
|
313
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
314
|
+
if (!res.pipeline) {
|
|
315
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
316
|
+
}
|
|
298
317
|
|
|
299
318
|
return res;
|
|
300
319
|
}
|
|
301
320
|
|
|
302
|
-
|
|
321
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
322
|
+
GGML_ASSERT(op->op == GGML_OP_CUMSUM);
|
|
323
|
+
|
|
324
|
+
char base[256];
|
|
325
|
+
char name[256];
|
|
326
|
+
|
|
327
|
+
snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type));
|
|
328
|
+
snprintf(name, 256, "%s", base);
|
|
329
|
+
|
|
330
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
331
|
+
if (!res.pipeline) {
|
|
332
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
333
|
+
}
|
|
334
|
+
|
|
335
|
+
return res;
|
|
336
|
+
}
|
|
337
|
+
|
|
338
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
339
|
+
GGML_ASSERT(op->op == GGML_OP_TRI);
|
|
340
|
+
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
|
|
341
|
+
|
|
342
|
+
char base[256];
|
|
343
|
+
char name[256];
|
|
344
|
+
|
|
345
|
+
const char * op_str = "tri";
|
|
346
|
+
const int ttype = op->op_params[0];
|
|
347
|
+
|
|
348
|
+
snprintf(base, 256, "kernel_%s_%s_%d", op_str, ggml_type_name(op->src[0]->type), ttype);
|
|
349
|
+
|
|
350
|
+
snprintf(name, 256, "%s", base);
|
|
351
|
+
|
|
352
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
353
|
+
if (!res.pipeline) {
|
|
354
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
355
|
+
}
|
|
356
|
+
|
|
357
|
+
return res;
|
|
358
|
+
}
|
|
359
|
+
|
|
360
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
303
361
|
GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);
|
|
304
362
|
|
|
305
363
|
char base[256];
|
|
@@ -316,19 +374,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_librar
|
|
|
316
374
|
snprintf(base, 256, "kernel_soft_max_%s%s", ggml_type_name(tsrc1), suffix);
|
|
317
375
|
snprintf(name, 256, "%s", base);
|
|
318
376
|
|
|
319
|
-
|
|
320
|
-
if (res) {
|
|
321
|
-
|
|
377
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
378
|
+
if (!res.pipeline) {
|
|
379
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
322
380
|
}
|
|
323
381
|
|
|
324
|
-
res =
|
|
325
|
-
|
|
326
|
-
ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
|
|
382
|
+
res.smem = 32*sizeof(float);
|
|
327
383
|
|
|
328
384
|
return res;
|
|
329
385
|
}
|
|
330
386
|
|
|
331
|
-
|
|
387
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
332
388
|
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
|
|
333
389
|
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
|
334
390
|
|
|
@@ -338,43 +394,82 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
|
|
|
338
394
|
char base[256];
|
|
339
395
|
char name[256];
|
|
340
396
|
|
|
341
|
-
|
|
342
|
-
snprintf(name, 256, "%s", base);
|
|
397
|
+
const char * suffix = "";
|
|
343
398
|
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
return res;
|
|
399
|
+
if (op->src[1]->ne[0] % 4 == 0) {
|
|
400
|
+
suffix = "_4";
|
|
347
401
|
}
|
|
348
402
|
|
|
349
|
-
|
|
403
|
+
snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
|
|
404
|
+
snprintf(name, 256, "%s", base);
|
|
405
|
+
|
|
406
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
407
|
+
if (!res.pipeline) {
|
|
408
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
409
|
+
}
|
|
350
410
|
|
|
351
411
|
return res;
|
|
352
412
|
}
|
|
353
413
|
|
|
354
|
-
|
|
414
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const ggml_tensor * op, int ssm_conv_bs) {
|
|
415
|
+
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
|
|
416
|
+
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
|
417
|
+
|
|
418
|
+
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
419
|
+
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
|
|
420
|
+
|
|
355
421
|
char base[256];
|
|
356
422
|
char name[256];
|
|
357
423
|
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
|
|
424
|
+
const char * suffix = "";
|
|
425
|
+
if (op->src[1]->ne[0] % 4 == 0) {
|
|
426
|
+
suffix = "_4";
|
|
362
427
|
}
|
|
363
|
-
snprintf(name, 256, "%s", base);
|
|
364
428
|
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
429
|
+
snprintf(base, 256, "kernel_ssm_conv_%s_%s_batched%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
|
|
430
|
+
snprintf(name, 256, "%s_ssm_conv_bs=%d", base, ssm_conv_bs);
|
|
431
|
+
|
|
432
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
433
|
+
if (!res.pipeline) {
|
|
434
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
435
|
+
|
|
436
|
+
ggml_metal_cv_set_int16(cv, ssm_conv_bs, FC_SSM_CONV + 0);
|
|
437
|
+
|
|
438
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
439
|
+
|
|
440
|
+
ggml_metal_cv_free(cv);
|
|
368
441
|
}
|
|
369
442
|
|
|
370
|
-
res
|
|
443
|
+
return res;
|
|
444
|
+
}
|
|
371
445
|
|
|
372
|
-
|
|
446
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
447
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
448
|
+
|
|
449
|
+
char base[256];
|
|
450
|
+
char name[256];
|
|
451
|
+
|
|
452
|
+
const int nsg = (ne00 + 31)/32;
|
|
453
|
+
|
|
454
|
+
snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
|
|
455
|
+
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
|
456
|
+
|
|
457
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
458
|
+
if (!res.pipeline) {
|
|
459
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
460
|
+
}
|
|
461
|
+
|
|
462
|
+
// Shared memory layout:
|
|
463
|
+
// - sgptg * NW floats for partial sums (nsg * 32)
|
|
464
|
+
// - sgptg floats for shared_x_dt (nsg)
|
|
465
|
+
// - sgptg floats for shared_dA (nsg)
|
|
466
|
+
// Total: nsg * (32 + 2) floats
|
|
467
|
+
res.smem = (32 + 2)*sizeof(float)*nsg;
|
|
373
468
|
|
|
374
469
|
return res;
|
|
375
470
|
}
|
|
376
471
|
|
|
377
|
-
|
|
472
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
378
473
|
char base[256];
|
|
379
474
|
char name[256];
|
|
380
475
|
|
|
@@ -404,41 +499,37 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t
|
|
|
404
499
|
|
|
405
500
|
snprintf(name, 256, "%s", base);
|
|
406
501
|
|
|
407
|
-
|
|
408
|
-
if (res) {
|
|
409
|
-
|
|
502
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
503
|
+
if (!res.pipeline) {
|
|
504
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
410
505
|
}
|
|
411
506
|
|
|
412
|
-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
413
|
-
|
|
414
507
|
return res;
|
|
415
508
|
}
|
|
416
509
|
|
|
417
|
-
|
|
510
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
|
|
418
511
|
char base[256];
|
|
419
512
|
char name[256];
|
|
420
513
|
|
|
421
514
|
snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
|
|
422
515
|
snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);
|
|
423
516
|
|
|
424
|
-
|
|
425
|
-
if (res) {
|
|
426
|
-
|
|
427
|
-
}
|
|
517
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
518
|
+
if (!res.pipeline) {
|
|
519
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
428
520
|
|
|
429
|
-
|
|
521
|
+
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
|
|
522
|
+
ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
|
|
430
523
|
|
|
431
|
-
|
|
432
|
-
ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
|
|
524
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
433
525
|
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
ggml_metal_cv_free(cv);
|
|
526
|
+
ggml_metal_cv_free(cv);
|
|
527
|
+
}
|
|
437
528
|
|
|
438
529
|
return res;
|
|
439
530
|
}
|
|
440
531
|
|
|
441
|
-
|
|
532
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
442
533
|
char base[256];
|
|
443
534
|
char name[256];
|
|
444
535
|
|
|
@@ -451,27 +542,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_
|
|
|
451
542
|
snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
|
|
452
543
|
snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
|
|
453
544
|
|
|
454
|
-
|
|
455
|
-
if (res) {
|
|
456
|
-
|
|
457
|
-
}
|
|
458
|
-
|
|
459
|
-
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
545
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
546
|
+
if (!res.pipeline) {
|
|
547
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
460
548
|
|
|
461
|
-
|
|
462
|
-
|
|
549
|
+
ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
|
|
550
|
+
ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
|
|
463
551
|
|
|
464
|
-
|
|
552
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
465
553
|
|
|
466
|
-
|
|
554
|
+
ggml_metal_cv_free(cv);
|
|
555
|
+
}
|
|
467
556
|
|
|
468
557
|
// when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes
|
|
469
|
-
|
|
558
|
+
res.smem = bc_out ? 8192 : 4096 + 2048;
|
|
470
559
|
|
|
471
560
|
return res;
|
|
472
561
|
}
|
|
473
562
|
|
|
474
|
-
|
|
563
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
475
564
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
476
565
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
477
566
|
|
|
@@ -626,49 +715,43 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
|
|
|
626
715
|
snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
|
|
627
716
|
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
|
628
717
|
|
|
629
|
-
|
|
630
|
-
if (res) {
|
|
631
|
-
|
|
632
|
-
}
|
|
633
|
-
|
|
634
|
-
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
718
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
719
|
+
if (!res.pipeline) {
|
|
720
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
635
721
|
|
|
636
|
-
|
|
722
|
+
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
|
|
637
723
|
|
|
638
|
-
|
|
724
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
639
725
|
|
|
640
|
-
|
|
726
|
+
ggml_metal_cv_free(cv);
|
|
727
|
+
}
|
|
641
728
|
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
729
|
+
res.nr0 = nr0;
|
|
730
|
+
res.nr1 = nr1;
|
|
731
|
+
res.nsg = nsg;
|
|
732
|
+
res.smem = smem;
|
|
646
733
|
|
|
647
734
|
return res;
|
|
648
735
|
}
|
|
649
736
|
|
|
650
|
-
|
|
737
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) {
|
|
651
738
|
char base[256];
|
|
652
739
|
char name[256];
|
|
653
740
|
|
|
654
741
|
snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
|
|
655
|
-
snprintf(name, 256, "%
|
|
742
|
+
snprintf(name, 256, "%s_ne02=%d", base, ne02);
|
|
656
743
|
|
|
657
|
-
|
|
658
|
-
if (res) {
|
|
659
|
-
|
|
744
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
745
|
+
if (!res.pipeline) {
|
|
746
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
660
747
|
}
|
|
661
748
|
|
|
662
|
-
res =
|
|
663
|
-
|
|
664
|
-
const size_t smem = (size_t) ne02*ne20*sizeof(uint16_t);
|
|
665
|
-
|
|
666
|
-
ggml_metal_pipeline_set_smem(res, smem);
|
|
749
|
+
res.smem = (size_t) ne02*ne20*sizeof(uint16_t);
|
|
667
750
|
|
|
668
751
|
return res;
|
|
669
752
|
}
|
|
670
753
|
|
|
671
|
-
|
|
754
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
672
755
|
char base[256];
|
|
673
756
|
char name[256];
|
|
674
757
|
|
|
@@ -680,25 +763,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_libra
|
|
|
680
763
|
snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
|
|
681
764
|
snprintf(name, 256, "%s_bci=%d", base, bc_inp);
|
|
682
765
|
|
|
683
|
-
|
|
684
|
-
if (res) {
|
|
685
|
-
|
|
686
|
-
}
|
|
687
|
-
|
|
688
|
-
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
766
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
767
|
+
if (!res.pipeline) {
|
|
768
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
689
769
|
|
|
690
|
-
|
|
770
|
+
ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
|
|
691
771
|
|
|
692
|
-
|
|
772
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
693
773
|
|
|
694
|
-
|
|
774
|
+
ggml_metal_cv_free(cv);
|
|
775
|
+
}
|
|
695
776
|
|
|
696
|
-
|
|
777
|
+
res.smem = 8192;
|
|
697
778
|
|
|
698
779
|
return res;
|
|
699
780
|
}
|
|
700
781
|
|
|
701
|
-
|
|
782
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
702
783
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
703
784
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
704
785
|
|
|
@@ -846,28 +927,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
|
|
|
846
927
|
snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
|
|
847
928
|
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
|
848
929
|
|
|
849
|
-
|
|
850
|
-
if (res) {
|
|
851
|
-
|
|
852
|
-
}
|
|
853
|
-
|
|
854
|
-
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
930
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
931
|
+
if (!res.pipeline) {
|
|
932
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
855
933
|
|
|
856
|
-
|
|
934
|
+
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
|
|
857
935
|
|
|
858
|
-
|
|
936
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
859
937
|
|
|
860
|
-
|
|
938
|
+
ggml_metal_cv_free(cv);
|
|
939
|
+
}
|
|
861
940
|
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
941
|
+
res.nr0 = nr0;
|
|
942
|
+
res.nr1 = nr1;
|
|
943
|
+
res.nsg = nsg;
|
|
944
|
+
res.smem = smem;
|
|
866
945
|
|
|
867
946
|
return res;
|
|
868
947
|
}
|
|
869
948
|
|
|
870
|
-
|
|
949
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
871
950
|
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
|
|
872
951
|
GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
|
|
873
952
|
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
|
|
@@ -878,19 +957,43 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax(ggml_metal_library_
|
|
|
878
957
|
snprintf(base, 256, "kernel_argmax_%s", ggml_type_name(op->src[0]->type));
|
|
879
958
|
snprintf(name, 256, "%s", base);
|
|
880
959
|
|
|
881
|
-
|
|
882
|
-
if (res) {
|
|
883
|
-
|
|
960
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
961
|
+
if (!res.pipeline) {
|
|
962
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
884
963
|
}
|
|
885
964
|
|
|
886
|
-
res =
|
|
965
|
+
res.smem = 32*(sizeof(float) + sizeof(int32_t));
|
|
887
966
|
|
|
888
|
-
|
|
967
|
+
return res;
|
|
968
|
+
}
|
|
969
|
+
|
|
970
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
971
|
+
assert(op->op == GGML_OP_ARGSORT);
|
|
972
|
+
|
|
973
|
+
char base[256];
|
|
974
|
+
char name[256];
|
|
975
|
+
|
|
976
|
+
ggml_sort_order order = (ggml_sort_order) op->op_params[0];
|
|
977
|
+
|
|
978
|
+
const char * order_str = "undefined";
|
|
979
|
+
switch (order) {
|
|
980
|
+
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
|
|
981
|
+
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
|
|
982
|
+
default: GGML_ABORT("fatal error");
|
|
983
|
+
};
|
|
984
|
+
|
|
985
|
+
snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
|
|
986
|
+
snprintf(name, 256, "%s", base);
|
|
987
|
+
|
|
988
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
989
|
+
if (!res.pipeline) {
|
|
990
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
991
|
+
}
|
|
889
992
|
|
|
890
993
|
return res;
|
|
891
994
|
}
|
|
892
995
|
|
|
893
|
-
|
|
996
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
894
997
|
assert(op->op == GGML_OP_ARGSORT);
|
|
895
998
|
|
|
896
999
|
char base[256];
|
|
@@ -905,26 +1008,165 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
|
|
|
905
1008
|
default: GGML_ABORT("fatal error");
|
|
906
1009
|
};
|
|
907
1010
|
|
|
1011
|
+
snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
|
|
1012
|
+
snprintf(name, 256, "%s", base);
|
|
1013
|
+
|
|
1014
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1015
|
+
if (!res.pipeline) {
|
|
1016
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1017
|
+
}
|
|
1018
|
+
|
|
1019
|
+
return res;
|
|
1020
|
+
}
|
|
1021
|
+
|
|
1022
|
+
// note: reuse the argsort kernel for top_k
|
|
1023
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1024
|
+
assert(op->op == GGML_OP_TOP_K);
|
|
1025
|
+
|
|
1026
|
+
char base[256];
|
|
1027
|
+
char name[256];
|
|
1028
|
+
|
|
1029
|
+
// note: the top_k kernel is always descending order
|
|
1030
|
+
ggml_sort_order order = GGML_SORT_ORDER_DESC;
|
|
1031
|
+
|
|
1032
|
+
const char * order_str = "undefined";
|
|
1033
|
+
switch (order) {
|
|
1034
|
+
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
|
|
1035
|
+
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
|
|
1036
|
+
default: GGML_ABORT("fatal error");
|
|
1037
|
+
};
|
|
1038
|
+
|
|
908
1039
|
snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
|
|
909
1040
|
snprintf(name, 256, "%s", base);
|
|
910
1041
|
|
|
911
|
-
|
|
912
|
-
if (res) {
|
|
913
|
-
|
|
1042
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1043
|
+
if (!res.pipeline) {
|
|
1044
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
914
1045
|
}
|
|
915
1046
|
|
|
916
|
-
res
|
|
1047
|
+
return res;
|
|
1048
|
+
}
|
|
1049
|
+
|
|
1050
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1051
|
+
assert(op->op == GGML_OP_TOP_K);
|
|
1052
|
+
|
|
1053
|
+
char base[256];
|
|
1054
|
+
char name[256];
|
|
1055
|
+
|
|
1056
|
+
ggml_sort_order order = GGML_SORT_ORDER_DESC;
|
|
1057
|
+
|
|
1058
|
+
const char * order_str = "undefined";
|
|
1059
|
+
switch (order) {
|
|
1060
|
+
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
|
|
1061
|
+
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
|
|
1062
|
+
default: GGML_ABORT("fatal error");
|
|
1063
|
+
};
|
|
1064
|
+
|
|
1065
|
+
snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
|
|
1066
|
+
snprintf(name, 256, "%s", base);
|
|
1067
|
+
|
|
1068
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1069
|
+
if (!res.pipeline) {
|
|
1070
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1071
|
+
}
|
|
1072
|
+
|
|
1073
|
+
return res;
|
|
1074
|
+
}
|
|
1075
|
+
|
|
1076
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
|
1077
|
+
ggml_metal_library_t lib,
|
|
1078
|
+
const struct ggml_tensor * op,
|
|
1079
|
+
bool has_mask,
|
|
1080
|
+
int32_t ncpsg) {
|
|
1081
|
+
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
|
1082
|
+
GGML_UNUSED(op);
|
|
1083
|
+
|
|
1084
|
+
char base[256];
|
|
1085
|
+
char name[256];
|
|
1086
|
+
|
|
1087
|
+
snprintf(base, 256, "kernel_%s",
|
|
1088
|
+
"flash_attn_ext_pad");
|
|
1089
|
+
|
|
1090
|
+
snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
|
|
1091
|
+
base,
|
|
1092
|
+
has_mask,
|
|
1093
|
+
ncpsg);
|
|
1094
|
+
|
|
1095
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1096
|
+
if (!res.pipeline) {
|
|
1097
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
1098
|
+
|
|
1099
|
+
ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
|
|
1100
|
+
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
|
|
1101
|
+
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
|
|
1102
|
+
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
|
|
1103
|
+
|
|
1104
|
+
//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
|
|
1105
|
+
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
|
|
1106
|
+
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
|
|
1107
|
+
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
|
|
1108
|
+
//ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
|
|
1109
|
+
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
|
|
1110
|
+
|
|
1111
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1112
|
+
|
|
1113
|
+
ggml_metal_cv_free(cv);
|
|
1114
|
+
}
|
|
1115
|
+
|
|
1116
|
+
return res;
|
|
1117
|
+
}
|
|
1118
|
+
|
|
1119
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk(
|
|
1120
|
+
ggml_metal_library_t lib,
|
|
1121
|
+
const struct ggml_tensor * op,
|
|
1122
|
+
int32_t nqptg,
|
|
1123
|
+
int32_t ncpsg) {
|
|
1124
|
+
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
|
1125
|
+
GGML_UNUSED(op);
|
|
1126
|
+
|
|
1127
|
+
char base[256];
|
|
1128
|
+
char name[256];
|
|
1129
|
+
|
|
1130
|
+
snprintf(base, 256, "kernel_%s",
|
|
1131
|
+
"flash_attn_ext_blk");
|
|
1132
|
+
|
|
1133
|
+
snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d",
|
|
1134
|
+
base,
|
|
1135
|
+
nqptg,
|
|
1136
|
+
ncpsg);
|
|
1137
|
+
|
|
1138
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1139
|
+
if (!res.pipeline) {
|
|
1140
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
1141
|
+
|
|
1142
|
+
//ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0);
|
|
1143
|
+
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
|
|
1144
|
+
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2);
|
|
1145
|
+
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3);
|
|
1146
|
+
|
|
1147
|
+
//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
|
|
1148
|
+
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
|
|
1149
|
+
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
|
|
1150
|
+
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
|
|
1151
|
+
ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
|
|
1152
|
+
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
|
|
1153
|
+
|
|
1154
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1155
|
+
|
|
1156
|
+
ggml_metal_cv_free(cv);
|
|
1157
|
+
}
|
|
917
1158
|
|
|
918
1159
|
return res;
|
|
919
1160
|
}
|
|
920
1161
|
|
|
921
|
-
|
|
1162
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext(
|
|
922
1163
|
ggml_metal_library_t lib,
|
|
923
1164
|
const ggml_tensor * op,
|
|
924
1165
|
bool has_mask,
|
|
925
1166
|
bool has_sinks,
|
|
926
1167
|
bool has_bias,
|
|
927
1168
|
bool has_scap,
|
|
1169
|
+
bool has_kvpad,
|
|
928
1170
|
int32_t nsg) {
|
|
929
1171
|
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
|
930
1172
|
|
|
@@ -937,52 +1179,59 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
|
|
937
1179
|
const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
|
|
938
1180
|
const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
|
|
939
1181
|
|
|
1182
|
+
// do bounds checks for the mask?
|
|
1183
|
+
const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
|
|
1184
|
+
|
|
940
1185
|
snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
|
|
941
1186
|
"flash_attn_ext",
|
|
942
1187
|
ggml_type_name(op->src[1]->type),
|
|
943
1188
|
dk,
|
|
944
1189
|
dv);
|
|
945
1190
|
|
|
946
|
-
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
|
|
1191
|
+
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
|
|
947
1192
|
base,
|
|
948
1193
|
has_mask,
|
|
949
1194
|
has_sinks,
|
|
950
1195
|
has_bias,
|
|
951
1196
|
has_scap,
|
|
1197
|
+
has_kvpad,
|
|
1198
|
+
bc_mask,
|
|
952
1199
|
ns10,
|
|
953
1200
|
ns20,
|
|
954
1201
|
nsg);
|
|
955
1202
|
|
|
956
|
-
|
|
957
|
-
if (res) {
|
|
958
|
-
|
|
959
|
-
}
|
|
1203
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1204
|
+
if (!res.pipeline) {
|
|
1205
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
960
1206
|
|
|
961
|
-
|
|
1207
|
+
ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0);
|
|
1208
|
+
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
|
|
1209
|
+
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
|
|
1210
|
+
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
|
|
1211
|
+
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
|
|
962
1212
|
|
|
963
|
-
|
|
964
|
-
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
|
|
965
|
-
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
|
|
966
|
-
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
|
|
1213
|
+
ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
|
|
967
1214
|
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
1215
|
+
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
|
|
1216
|
+
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
|
|
1217
|
+
ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22);
|
|
971
1218
|
|
|
972
|
-
|
|
1219
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
973
1220
|
|
|
974
|
-
|
|
1221
|
+
ggml_metal_cv_free(cv);
|
|
1222
|
+
}
|
|
975
1223
|
|
|
976
1224
|
return res;
|
|
977
1225
|
}
|
|
978
1226
|
|
|
979
|
-
|
|
1227
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
|
980
1228
|
ggml_metal_library_t lib,
|
|
981
1229
|
const ggml_tensor * op,
|
|
982
1230
|
bool has_mask,
|
|
983
1231
|
bool has_sinks,
|
|
984
1232
|
bool has_bias,
|
|
985
1233
|
bool has_scap,
|
|
1234
|
+
bool has_kvpad,
|
|
986
1235
|
int32_t nsg,
|
|
987
1236
|
int32_t nwg) {
|
|
988
1237
|
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
|
@@ -1002,41 +1251,41 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
|
|
1002
1251
|
dk,
|
|
1003
1252
|
dv);
|
|
1004
1253
|
|
|
1005
|
-
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%
|
|
1254
|
+
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
|
|
1006
1255
|
base,
|
|
1007
1256
|
has_mask,
|
|
1008
1257
|
has_sinks,
|
|
1009
1258
|
has_bias,
|
|
1010
1259
|
has_scap,
|
|
1260
|
+
has_kvpad,
|
|
1011
1261
|
ns10,
|
|
1012
1262
|
ns20,
|
|
1013
1263
|
nsg, nwg);
|
|
1014
1264
|
|
|
1015
|
-
|
|
1016
|
-
if (res) {
|
|
1017
|
-
|
|
1018
|
-
}
|
|
1019
|
-
|
|
1020
|
-
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
1265
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1266
|
+
if (!res.pipeline) {
|
|
1267
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
1021
1268
|
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1269
|
+
ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0);
|
|
1270
|
+
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
|
|
1271
|
+
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
|
|
1272
|
+
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
|
|
1273
|
+
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
|
|
1026
1274
|
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
1275
|
+
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
|
|
1276
|
+
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
|
|
1277
|
+
ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22);
|
|
1278
|
+
ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23);
|
|
1031
1279
|
|
|
1032
|
-
|
|
1280
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1033
1281
|
|
|
1034
|
-
|
|
1282
|
+
ggml_metal_cv_free(cv);
|
|
1283
|
+
}
|
|
1035
1284
|
|
|
1036
1285
|
return res;
|
|
1037
1286
|
}
|
|
1038
1287
|
|
|
1039
|
-
|
|
1288
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
|
|
1040
1289
|
ggml_metal_library_t lib,
|
|
1041
1290
|
const ggml_tensor * op,
|
|
1042
1291
|
int32_t dv,
|
|
@@ -1049,26 +1298,24 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
|
|
|
1049
1298
|
snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
|
|
1050
1299
|
snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg);
|
|
1051
1300
|
|
|
1052
|
-
|
|
1053
|
-
if (res) {
|
|
1054
|
-
|
|
1055
|
-
}
|
|
1301
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1302
|
+
if (!res.pipeline) {
|
|
1303
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
1056
1304
|
|
|
1057
|
-
|
|
1305
|
+
ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0);
|
|
1306
|
+
ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1);
|
|
1058
1307
|
|
|
1059
|
-
|
|
1060
|
-
ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1);
|
|
1308
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1061
1309
|
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
ggml_metal_cv_free(cv);
|
|
1310
|
+
ggml_metal_cv_free(cv);
|
|
1311
|
+
}
|
|
1065
1312
|
|
|
1066
1313
|
return res;
|
|
1067
1314
|
|
|
1068
1315
|
GGML_UNUSED(op);
|
|
1069
1316
|
}
|
|
1070
1317
|
|
|
1071
|
-
|
|
1318
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(
|
|
1072
1319
|
ggml_metal_library_t lib,
|
|
1073
1320
|
ggml_op op,
|
|
1074
1321
|
int32_t n_fuse,
|
|
@@ -1093,17 +1340,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin(
|
|
|
1093
1340
|
|
|
1094
1341
|
snprintf(name, 256, "%s", base);
|
|
1095
1342
|
|
|
1096
|
-
|
|
1097
|
-
if (res) {
|
|
1098
|
-
|
|
1343
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1344
|
+
if (!res.pipeline) {
|
|
1345
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1099
1346
|
}
|
|
1100
1347
|
|
|
1101
|
-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1102
|
-
|
|
1103
1348
|
return res;
|
|
1104
1349
|
}
|
|
1105
1350
|
|
|
1106
|
-
|
|
1351
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1107
1352
|
assert(op->op == GGML_OP_L2_NORM);
|
|
1108
1353
|
|
|
1109
1354
|
GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
|
|
@@ -1115,19 +1360,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library
|
|
|
1115
1360
|
snprintf(base, 256, "kernel_l2_norm_f32");
|
|
1116
1361
|
snprintf(name, 256, "%s", base);
|
|
1117
1362
|
|
|
1118
|
-
|
|
1119
|
-
if (res) {
|
|
1120
|
-
|
|
1363
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1364
|
+
if (!res.pipeline) {
|
|
1365
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1121
1366
|
}
|
|
1122
1367
|
|
|
1123
|
-
res =
|
|
1124
|
-
|
|
1125
|
-
ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
|
|
1368
|
+
res.smem = 32*sizeof(float);
|
|
1126
1369
|
|
|
1127
1370
|
return res;
|
|
1128
1371
|
}
|
|
1129
1372
|
|
|
1130
|
-
|
|
1373
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1131
1374
|
assert(op->op == GGML_OP_GROUP_NORM);
|
|
1132
1375
|
|
|
1133
1376
|
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
@@ -1138,19 +1381,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_libr
|
|
|
1138
1381
|
snprintf(base, 256, "kernel_group_norm_f32");
|
|
1139
1382
|
snprintf(name, 256, "%s", base);
|
|
1140
1383
|
|
|
1141
|
-
|
|
1142
|
-
if (res) {
|
|
1143
|
-
|
|
1384
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1385
|
+
if (!res.pipeline) {
|
|
1386
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1144
1387
|
}
|
|
1145
1388
|
|
|
1146
|
-
res =
|
|
1147
|
-
|
|
1148
|
-
ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
|
|
1389
|
+
res.smem = 32*sizeof(float);
|
|
1149
1390
|
|
|
1150
1391
|
return res;
|
|
1151
1392
|
}
|
|
1152
1393
|
|
|
1153
|
-
|
|
1394
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) {
|
|
1154
1395
|
assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM);
|
|
1155
1396
|
|
|
1156
1397
|
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
@@ -1183,19 +1424,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t
|
|
|
1183
1424
|
|
|
1184
1425
|
snprintf(name, 256, "%s", base);
|
|
1185
1426
|
|
|
1186
|
-
|
|
1187
|
-
if (res) {
|
|
1188
|
-
|
|
1427
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1428
|
+
if (!res.pipeline) {
|
|
1429
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1189
1430
|
}
|
|
1190
1431
|
|
|
1191
|
-
res =
|
|
1192
|
-
|
|
1193
|
-
ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
|
|
1432
|
+
res.smem = 32*sizeof(float);
|
|
1194
1433
|
|
|
1195
1434
|
return res;
|
|
1196
1435
|
}
|
|
1197
1436
|
|
|
1198
|
-
|
|
1437
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1199
1438
|
assert(op->op == GGML_OP_ROPE);
|
|
1200
1439
|
|
|
1201
1440
|
char base[256];
|
|
@@ -1205,11 +1444,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
|
|
|
1205
1444
|
|
|
1206
1445
|
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
|
1207
1446
|
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
|
1447
|
+
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
|
|
1208
1448
|
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
|
1209
1449
|
|
|
1210
1450
|
if (is_neox) {
|
|
1211
1451
|
snprintf(base, 256, "kernel_rope_neox_%s", ggml_type_name(op->src[0]->type));
|
|
1212
|
-
} else if (is_mrope && !is_vision) {
|
|
1452
|
+
} else if ((is_mrope || is_imrope) && !is_vision) {
|
|
1213
1453
|
GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
|
|
1214
1454
|
snprintf(base, 256, "kernel_rope_multi_%s", ggml_type_name(op->src[0]->type));
|
|
1215
1455
|
} else if (is_vision) {
|
|
@@ -1219,19 +1459,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
|
|
|
1219
1459
|
snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type));
|
|
1220
1460
|
}
|
|
1221
1461
|
|
|
1222
|
-
snprintf(name, 256, "%
|
|
1462
|
+
snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
|
|
1223
1463
|
|
|
1224
|
-
|
|
1225
|
-
if (res) {
|
|
1226
|
-
|
|
1227
|
-
}
|
|
1464
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1465
|
+
if (!res.pipeline) {
|
|
1466
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
1228
1467
|
|
|
1229
|
-
|
|
1468
|
+
ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
|
|
1469
|
+
|
|
1470
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1471
|
+
|
|
1472
|
+
ggml_metal_cv_free(cv);
|
|
1473
|
+
}
|
|
1230
1474
|
|
|
1231
1475
|
return res;
|
|
1232
1476
|
}
|
|
1233
1477
|
|
|
1234
|
-
|
|
1478
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1235
1479
|
assert(op->op == GGML_OP_IM2COL);
|
|
1236
1480
|
|
|
1237
1481
|
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
|
|
@@ -1244,17 +1488,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_
|
|
|
1244
1488
|
snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type));
|
|
1245
1489
|
snprintf(name, 256, "%s", base);
|
|
1246
1490
|
|
|
1247
|
-
|
|
1248
|
-
if (res) {
|
|
1249
|
-
|
|
1491
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1492
|
+
if (!res.pipeline) {
|
|
1493
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1250
1494
|
}
|
|
1251
1495
|
|
|
1252
|
-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1253
|
-
|
|
1254
1496
|
return res;
|
|
1255
1497
|
}
|
|
1256
1498
|
|
|
1257
|
-
|
|
1499
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1258
1500
|
assert(op->op == GGML_OP_CONV_TRANSPOSE_1D);
|
|
1259
1501
|
|
|
1260
1502
|
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
@@ -1269,17 +1511,60 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_met
|
|
|
1269
1511
|
snprintf(base, 256, "kernel_conv_transpose_1d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
|
|
1270
1512
|
snprintf(name, 256, "%s", base);
|
|
1271
1513
|
|
|
1272
|
-
|
|
1273
|
-
if (res) {
|
|
1274
|
-
|
|
1514
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1515
|
+
if (!res.pipeline) {
|
|
1516
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1275
1517
|
}
|
|
1276
1518
|
|
|
1277
|
-
res
|
|
1519
|
+
return res;
|
|
1520
|
+
}
|
|
1521
|
+
|
|
1522
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1523
|
+
assert(op->op == GGML_OP_CONV_TRANSPOSE_2D);
|
|
1524
|
+
|
|
1525
|
+
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
1526
|
+
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
|
|
1527
|
+
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
|
|
1528
|
+
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
|
1529
|
+
GGML_ASSERT(op->type == GGML_TYPE_F32);
|
|
1530
|
+
|
|
1531
|
+
char base[256];
|
|
1532
|
+
char name[256];
|
|
1533
|
+
|
|
1534
|
+
snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
|
|
1535
|
+
snprintf(name, 256, "%s", base);
|
|
1536
|
+
|
|
1537
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1538
|
+
if (!res.pipeline) {
|
|
1539
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1540
|
+
}
|
|
1541
|
+
|
|
1542
|
+
return res;
|
|
1543
|
+
}
|
|
1544
|
+
|
|
1545
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1546
|
+
assert(op->op == GGML_OP_CONV_2D);
|
|
1547
|
+
|
|
1548
|
+
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
1549
|
+
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
|
|
1550
|
+
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
|
1551
|
+
GGML_ASSERT(op->type == GGML_TYPE_F32);
|
|
1552
|
+
|
|
1553
|
+
char base[256];
|
|
1554
|
+
char name[256];
|
|
1555
|
+
|
|
1556
|
+
snprintf(base, 256, "kernel_conv_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
|
|
1557
|
+
snprintf(name, 256, "%s", base);
|
|
1558
|
+
|
|
1559
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1560
|
+
if (!res.pipeline) {
|
|
1561
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1562
|
+
}
|
|
1278
1563
|
|
|
1279
1564
|
return res;
|
|
1280
1565
|
}
|
|
1281
1566
|
|
|
1282
|
-
|
|
1567
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1283
1568
|
assert(op->op == GGML_OP_UPSCALE);
|
|
1284
1569
|
|
|
1285
1570
|
char base[256];
|
|
@@ -1288,17 +1573,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library
|
|
|
1288
1573
|
snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type));
|
|
1289
1574
|
snprintf(name, 256, "%s", base);
|
|
1290
1575
|
|
|
1291
|
-
|
|
1292
|
-
if (res) {
|
|
1293
|
-
|
|
1576
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1577
|
+
if (!res.pipeline) {
|
|
1578
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1294
1579
|
}
|
|
1295
1580
|
|
|
1296
|
-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1297
|
-
|
|
1298
1581
|
return res;
|
|
1299
1582
|
}
|
|
1300
1583
|
|
|
1301
|
-
|
|
1584
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1302
1585
|
assert(op->op == GGML_OP_PAD);
|
|
1303
1586
|
|
|
1304
1587
|
char base[256];
|
|
@@ -1307,8 +1590,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t l
|
|
|
1307
1590
|
snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type));
|
|
1308
1591
|
snprintf(name, 256, "%s", base);
|
|
1309
1592
|
|
|
1310
|
-
|
|
1311
|
-
if (res) {
|
|
1593
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1594
|
+
if (res.pipeline) {
|
|
1312
1595
|
return res;
|
|
1313
1596
|
}
|
|
1314
1597
|
|
|
@@ -1317,7 +1600,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t l
|
|
|
1317
1600
|
return res;
|
|
1318
1601
|
}
|
|
1319
1602
|
|
|
1320
|
-
|
|
1603
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1321
1604
|
assert(op->op == GGML_OP_PAD_REFLECT_1D);
|
|
1322
1605
|
|
|
1323
1606
|
char base[256];
|
|
@@ -1326,17 +1609,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_
|
|
|
1326
1609
|
snprintf(base, 256, "kernel_pad_reflect_1d_%s", ggml_type_name(op->src[0]->type));
|
|
1327
1610
|
snprintf(name, 256, "%s", base);
|
|
1328
1611
|
|
|
1329
|
-
|
|
1330
|
-
if (res) {
|
|
1331
|
-
|
|
1612
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1613
|
+
if (!res.pipeline) {
|
|
1614
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1332
1615
|
}
|
|
1333
1616
|
|
|
1334
|
-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1335
|
-
|
|
1336
1617
|
return res;
|
|
1337
1618
|
}
|
|
1338
1619
|
|
|
1339
|
-
|
|
1620
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1340
1621
|
assert(op->op == GGML_OP_ARANGE);
|
|
1341
1622
|
|
|
1342
1623
|
char base[256];
|
|
@@ -1345,17 +1626,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange(ggml_metal_library_
|
|
|
1345
1626
|
snprintf(base, 256, "kernel_arange_%s", ggml_type_name(op->type));
|
|
1346
1627
|
snprintf(name, 256, "%s", base);
|
|
1347
1628
|
|
|
1348
|
-
|
|
1349
|
-
if (res) {
|
|
1350
|
-
|
|
1629
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1630
|
+
if (!res.pipeline) {
|
|
1631
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1351
1632
|
}
|
|
1352
1633
|
|
|
1353
|
-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1354
|
-
|
|
1355
1634
|
return res;
|
|
1356
1635
|
}
|
|
1357
1636
|
|
|
1358
|
-
|
|
1637
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1359
1638
|
assert(op->op == GGML_OP_TIMESTEP_EMBEDDING);
|
|
1360
1639
|
|
|
1361
1640
|
char base[256];
|
|
@@ -1364,13 +1643,101 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me
|
|
|
1364
1643
|
snprintf(base, 256, "kernel_timestep_embedding_%s", ggml_type_name(op->src[0]->type));
|
|
1365
1644
|
snprintf(name, 256, "%s", base);
|
|
1366
1645
|
|
|
1367
|
-
|
|
1368
|
-
if (res) {
|
|
1369
|
-
|
|
1646
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1647
|
+
if (!res.pipeline) {
|
|
1648
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1370
1649
|
}
|
|
1371
1650
|
|
|
1372
|
-
res
|
|
1651
|
+
return res;
|
|
1652
|
+
}
|
|
1653
|
+
|
|
1654
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1655
|
+
assert(op->op == GGML_OP_OPT_STEP_ADAMW);
|
|
1656
|
+
|
|
1657
|
+
char base[256];
|
|
1658
|
+
char name[256];
|
|
1659
|
+
|
|
1660
|
+
snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
|
|
1661
|
+
snprintf(name, 256, "%s", base);
|
|
1662
|
+
|
|
1663
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1664
|
+
if (!res.pipeline) {
|
|
1665
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1666
|
+
}
|
|
1667
|
+
|
|
1668
|
+
return res;
|
|
1669
|
+
}
|
|
1670
|
+
|
|
1671
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1672
|
+
assert(op->op == GGML_OP_OPT_STEP_SGD);
|
|
1673
|
+
|
|
1674
|
+
char base[256];
|
|
1675
|
+
char name[256];
|
|
1676
|
+
|
|
1677
|
+
snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
|
|
1678
|
+
snprintf(name, 256, "%s", base);
|
|
1679
|
+
|
|
1680
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1681
|
+
if (!res.pipeline) {
|
|
1682
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1683
|
+
}
|
|
1684
|
+
|
|
1685
|
+
return res;
|
|
1686
|
+
}
|
|
1687
|
+
|
|
1688
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1689
|
+
GGML_ASSERT(op->type == GGML_TYPE_I64);
|
|
1690
|
+
|
|
1691
|
+
char base[256];
|
|
1692
|
+
char name[256];
|
|
1693
|
+
|
|
1694
|
+
snprintf(base, 256, "kernel_memset_%s", ggml_type_name(op->type));
|
|
1695
|
+
snprintf(name, 256, "%s", base);
|
|
1696
|
+
|
|
1697
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1698
|
+
if (!res.pipeline) {
|
|
1699
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1700
|
+
}
|
|
1373
1701
|
|
|
1374
1702
|
return res;
|
|
1375
1703
|
}
|
|
1376
1704
|
|
|
1705
|
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal(ggml_metal_library_t lib, const ggml_tensor * op) {
|
|
1706
|
+
assert(op->op == GGML_OP_COUNT_EQUAL);
|
|
1707
|
+
|
|
1708
|
+
GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne);
|
|
1709
|
+
|
|
1710
|
+
GGML_ASSERT(op->src[0]->type == op->src[1]->type);
|
|
1711
|
+
GGML_ASSERT(op->src[0]->type == GGML_TYPE_I32);
|
|
1712
|
+
GGML_ASSERT(op->type == GGML_TYPE_I64);
|
|
1713
|
+
|
|
1714
|
+
// note: the kernel only supports i32 output due to metal atomic add only supporting atomic_int
|
|
1715
|
+
GGML_ASSERT(ggml_nelements(op->src[0]) < (1LL << 31));
|
|
1716
|
+
|
|
1717
|
+
char base[256];
|
|
1718
|
+
char name[256];
|
|
1719
|
+
|
|
1720
|
+
int nsg = 1;
|
|
1721
|
+
while (32*nsg < ne00 && nsg < 32) {
|
|
1722
|
+
nsg *= 2;
|
|
1723
|
+
}
|
|
1724
|
+
|
|
1725
|
+
snprintf(base, 256, "kernel_count_equal_%s", ggml_type_name(op->src[0]->type));
|
|
1726
|
+
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
|
1727
|
+
|
|
1728
|
+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
|
1729
|
+
if (!res.pipeline) {
|
|
1730
|
+
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
|
1731
|
+
|
|
1732
|
+
ggml_metal_cv_set_int16(cv, nsg, FC_COUNT_EQUAL + 0);
|
|
1733
|
+
|
|
1734
|
+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1735
|
+
|
|
1736
|
+
ggml_metal_cv_free(cv);
|
|
1737
|
+
}
|
|
1738
|
+
|
|
1739
|
+
res.smem = 32 * sizeof(int32_t);
|
|
1740
|
+
res.nsg = nsg;
|
|
1741
|
+
|
|
1742
|
+
return res;
|
|
1743
|
+
}
|