whispercpp 1.3.4 → 1.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +47 -23
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/cli/cli.cpp +121 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +10 -11
- data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
- data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
- data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
- data/ext/sources/examples/talk-llama/llama-context.h +57 -9
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
- data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
- data/ext/sources/examples/talk-llama/llama-model.h +44 -3
- data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
- data/ext/sources/examples/talk-llama/llama.cpp +729 -2
- data/ext/sources/examples/talk-llama/llama.h +152 -14
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +82 -54
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +4 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +190 -12
- data/ext/sources/ggml/src/CMakeLists.txt +82 -11
- data/ext/sources/ggml/src/ggml-alloc.c +124 -41
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
- data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
- data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
- data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
- data/ext/sources/ggml/src/ggml-impl.h +67 -6
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +425 -33
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +101 -35
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +119 -2
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +7 -0
- data/whispercpp.gemspec +1 -1
- metadata +287 -34
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -10,6 +10,8 @@
|
|
|
10
10
|
|
|
11
11
|
#include <cassert>
|
|
12
12
|
#include <algorithm>
|
|
13
|
+
#include <limits>
|
|
14
|
+
#include <cmath>
|
|
13
15
|
|
|
14
16
|
static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) {
|
|
15
17
|
if (!t) {
|
|
@@ -219,13 +221,17 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
219
221
|
}
|
|
220
222
|
|
|
221
223
|
if (ctx->debug_graph > 0) {
|
|
222
|
-
GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), is_concurrent ? "(concurrent)" : "");
|
|
224
|
+
GGML_LOG_DEBUG("%s: node[%5d] - %-12s %-12s %s\n", __func__, idx, ggml_op_name(node->op), ggml_get_name(node), is_concurrent ? "(concurrent)" : "");
|
|
223
225
|
}
|
|
224
226
|
if (ctx->debug_graph > 1) {
|
|
225
227
|
GGML_TENSOR_LOCALS( int64_t, ne0, node->src[0], ne);
|
|
226
228
|
GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
|
|
227
229
|
GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
|
|
228
230
|
GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
|
|
231
|
+
GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
|
|
232
|
+
GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
|
|
233
|
+
GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
|
|
234
|
+
GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
|
|
229
235
|
GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
|
|
230
236
|
GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);
|
|
231
237
|
|
|
@@ -237,6 +243,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
237
243
|
GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
|
|
238
244
|
ggml_is_contiguous(node->src[1]), node->src[1]->name);
|
|
239
245
|
}
|
|
246
|
+
if (node->src[2]) {
|
|
247
|
+
GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
|
|
248
|
+
ggml_is_contiguous(node->src[2]), node->src[2]->name);
|
|
249
|
+
}
|
|
250
|
+
if (node->src[3]) {
|
|
251
|
+
GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
|
|
252
|
+
ggml_is_contiguous(node->src[3]), node->src[3]->name);
|
|
253
|
+
}
|
|
240
254
|
if (node) {
|
|
241
255
|
GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
|
|
242
256
|
node->name);
|
|
@@ -272,6 +286,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
272
286
|
{
|
|
273
287
|
n_fuse = ggml_metal_op_scale(ctx, idx);
|
|
274
288
|
} break;
|
|
289
|
+
case GGML_OP_FILL:
|
|
290
|
+
{
|
|
291
|
+
n_fuse = ggml_metal_op_fill(ctx, idx);
|
|
292
|
+
} break;
|
|
275
293
|
case GGML_OP_CLAMP:
|
|
276
294
|
{
|
|
277
295
|
n_fuse = ggml_metal_op_clamp(ctx, idx);
|
|
@@ -289,11 +307,19 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
289
307
|
{
|
|
290
308
|
n_fuse = ggml_metal_op_glu(ctx, idx);
|
|
291
309
|
} break;
|
|
310
|
+
case GGML_OP_SUM:
|
|
311
|
+
{
|
|
312
|
+
n_fuse = ggml_metal_op_sum(ctx, idx);
|
|
313
|
+
} break;
|
|
292
314
|
case GGML_OP_SUM_ROWS:
|
|
293
315
|
case GGML_OP_MEAN:
|
|
294
316
|
{
|
|
295
317
|
n_fuse = ggml_metal_op_sum_rows(ctx, idx);
|
|
296
318
|
} break;
|
|
319
|
+
case GGML_OP_CUMSUM:
|
|
320
|
+
{
|
|
321
|
+
n_fuse = ggml_metal_op_cumsum(ctx, idx);
|
|
322
|
+
} break;
|
|
297
323
|
case GGML_OP_SOFT_MAX:
|
|
298
324
|
{
|
|
299
325
|
n_fuse = ggml_metal_op_soft_max(ctx, idx);
|
|
@@ -348,10 +374,18 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
348
374
|
{
|
|
349
375
|
n_fuse = ggml_metal_op_im2col(ctx, idx);
|
|
350
376
|
} break;
|
|
377
|
+
case GGML_OP_CONV_2D:
|
|
378
|
+
{
|
|
379
|
+
n_fuse = ggml_metal_op_conv_2d(ctx, idx);
|
|
380
|
+
} break;
|
|
351
381
|
case GGML_OP_CONV_TRANSPOSE_1D:
|
|
352
382
|
{
|
|
353
383
|
n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
|
|
354
384
|
} break;
|
|
385
|
+
case GGML_OP_CONV_TRANSPOSE_2D:
|
|
386
|
+
{
|
|
387
|
+
n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
|
|
388
|
+
} break;
|
|
355
389
|
case GGML_OP_UPSCALE:
|
|
356
390
|
{
|
|
357
391
|
n_fuse = ggml_metal_op_upscale(ctx, idx);
|
|
@@ -376,10 +410,18 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
376
410
|
{
|
|
377
411
|
n_fuse = ggml_metal_op_argsort(ctx, idx);
|
|
378
412
|
} break;
|
|
413
|
+
case GGML_OP_TOP_K:
|
|
414
|
+
{
|
|
415
|
+
n_fuse = ggml_metal_op_top_k(ctx, idx);
|
|
416
|
+
} break;
|
|
379
417
|
case GGML_OP_LEAKY_RELU:
|
|
380
418
|
{
|
|
381
419
|
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
|
|
382
420
|
} break;
|
|
421
|
+
case GGML_OP_TRI:
|
|
422
|
+
{
|
|
423
|
+
n_fuse = ggml_metal_op_tri(ctx, idx);
|
|
424
|
+
} break;
|
|
383
425
|
case GGML_OP_FLASH_ATTN_EXT:
|
|
384
426
|
{
|
|
385
427
|
n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
|
|
@@ -398,7 +440,19 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
398
440
|
{
|
|
399
441
|
n_fuse = ggml_metal_op_argmax(ctx, idx);
|
|
400
442
|
} break;
|
|
401
|
-
|
|
443
|
+
case GGML_OP_OPT_STEP_ADAMW:
|
|
444
|
+
{
|
|
445
|
+
n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
|
|
446
|
+
} break;
|
|
447
|
+
case GGML_OP_OPT_STEP_SGD:
|
|
448
|
+
{
|
|
449
|
+
n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
|
|
450
|
+
} break;
|
|
451
|
+
case GGML_OP_COUNT_EQUAL:
|
|
452
|
+
{
|
|
453
|
+
n_fuse = ggml_metal_op_count_equal(ctx, idx);
|
|
454
|
+
} break;
|
|
455
|
+
default:
|
|
402
456
|
{
|
|
403
457
|
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
|
|
404
458
|
GGML_ABORT("fatal error");
|
|
@@ -482,7 +536,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
|
|
|
482
536
|
/*.dim =*/ dim,
|
|
483
537
|
};
|
|
484
538
|
|
|
485
|
-
|
|
539
|
+
auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
|
|
486
540
|
|
|
487
541
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
488
542
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -506,9 +560,9 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
|
|
|
506
560
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
507
561
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
508
562
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
509
|
-
GGML_TENSOR_LOCALS(
|
|
563
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
510
564
|
|
|
511
|
-
|
|
565
|
+
auto pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
|
|
512
566
|
|
|
513
567
|
ggml_metal_kargs_repeat args = {
|
|
514
568
|
/*.ne00 =*/ ne00,
|
|
@@ -552,7 +606,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|
|
552
606
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
553
607
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
554
608
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
555
|
-
GGML_TENSOR_LOCALS(
|
|
609
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
556
610
|
|
|
557
611
|
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
|
|
558
612
|
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
|
@@ -574,9 +628,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|
|
574
628
|
// TODO: make a simpler cpy_bytes kernel
|
|
575
629
|
|
|
576
630
|
//const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
|
|
577
|
-
|
|
631
|
+
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
|
578
632
|
|
|
579
633
|
ggml_metal_kargs_cpy args = {
|
|
634
|
+
/*.nk0 =*/ ne00,
|
|
580
635
|
/*.ne00 =*/ ne00,
|
|
581
636
|
/*.ne01 =*/ ne01,
|
|
582
637
|
/*.ne02 =*/ ne02,
|
|
@@ -636,7 +691,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|
|
636
691
|
/*.o1 =*/ { 0 },
|
|
637
692
|
};
|
|
638
693
|
|
|
639
|
-
|
|
694
|
+
auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
|
|
640
695
|
|
|
641
696
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
642
697
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -660,7 +715,7 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
|
|
|
660
715
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
661
716
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
662
717
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
663
|
-
GGML_TENSOR_LOCALS(
|
|
718
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
664
719
|
|
|
665
720
|
float scale;
|
|
666
721
|
float bias;
|
|
@@ -678,7 +733,42 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
|
|
|
678
733
|
n /= 4;
|
|
679
734
|
}
|
|
680
735
|
|
|
681
|
-
|
|
736
|
+
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
|
737
|
+
|
|
738
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
739
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
740
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
741
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
742
|
+
|
|
743
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
|
744
|
+
|
|
745
|
+
return 1;
|
|
746
|
+
}
|
|
747
|
+
|
|
748
|
+
int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
|
|
749
|
+
ggml_tensor * op = ctx->node(idx);
|
|
750
|
+
|
|
751
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
752
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
753
|
+
|
|
754
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
755
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
756
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
757
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
758
|
+
|
|
759
|
+
const float val = ggml_get_op_params_f32(op, 0);
|
|
760
|
+
|
|
761
|
+
ggml_metal_kargs_fill args = {
|
|
762
|
+
/*.val =*/ val
|
|
763
|
+
};
|
|
764
|
+
|
|
765
|
+
int64_t n = ggml_nelements(op);
|
|
766
|
+
|
|
767
|
+
if (n % 4 == 0) {
|
|
768
|
+
n /= 4;
|
|
769
|
+
}
|
|
770
|
+
|
|
771
|
+
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
|
682
772
|
|
|
683
773
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
684
774
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -699,7 +789,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
|
|
|
699
789
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
700
790
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
701
791
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
702
|
-
GGML_TENSOR_LOCALS(
|
|
792
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
703
793
|
|
|
704
794
|
float min;
|
|
705
795
|
float max;
|
|
@@ -717,7 +807,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
|
|
|
717
807
|
n /= 4;
|
|
718
808
|
}
|
|
719
809
|
|
|
720
|
-
|
|
810
|
+
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
|
721
811
|
|
|
722
812
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
723
813
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -738,7 +828,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
|
|
738
828
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
739
829
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
740
830
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
741
|
-
GGML_TENSOR_LOCALS(
|
|
831
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
742
832
|
|
|
743
833
|
int64_t n = ggml_nelements(op);
|
|
744
834
|
|
|
@@ -746,7 +836,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
|
|
746
836
|
n /= 4;
|
|
747
837
|
}
|
|
748
838
|
|
|
749
|
-
|
|
839
|
+
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
|
750
840
|
|
|
751
841
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
752
842
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
|
|
@@ -768,13 +858,13 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
|
|
768
858
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
769
859
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
770
860
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
771
|
-
GGML_TENSOR_LOCALS(
|
|
861
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
772
862
|
|
|
773
863
|
if (op->src[1]) {
|
|
774
864
|
GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
|
|
775
865
|
}
|
|
776
866
|
|
|
777
|
-
|
|
867
|
+
auto pipeline = ggml_metal_library_get_pipeline_glu(lib, op);
|
|
778
868
|
|
|
779
869
|
const int32_t swp = ggml_get_op_params_i32(op, 1);
|
|
780
870
|
const float alpha = ggml_get_op_params_f32(op, 2);
|
|
@@ -800,18 +890,6 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
|
|
800
890
|
|
|
801
891
|
const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
|
|
802
892
|
|
|
803
|
-
//[encoder setComputePipelineState:pipeline];
|
|
804
|
-
//[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
805
|
-
//if (src1) {
|
|
806
|
-
// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
807
|
-
//} else {
|
|
808
|
-
// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
809
|
-
//}
|
|
810
|
-
//[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
811
|
-
//[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
|
812
|
-
|
|
813
|
-
//[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
814
|
-
|
|
815
893
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
816
894
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
817
895
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
@@ -827,6 +905,43 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
|
|
827
905
|
return 1;
|
|
828
906
|
}
|
|
829
907
|
|
|
908
|
+
int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
|
|
909
|
+
ggml_tensor * op = ctx->node(idx);
|
|
910
|
+
|
|
911
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
912
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
913
|
+
|
|
914
|
+
const uint64_t n = (uint64_t) ggml_nelements(op->src[0]);
|
|
915
|
+
|
|
916
|
+
ggml_metal_kargs_sum args = {
|
|
917
|
+
/*.np =*/ n,
|
|
918
|
+
};
|
|
919
|
+
|
|
920
|
+
auto pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
|
|
921
|
+
|
|
922
|
+
int nth = 32; // SIMD width
|
|
923
|
+
|
|
924
|
+
while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
925
|
+
nth *= 2;
|
|
926
|
+
}
|
|
927
|
+
|
|
928
|
+
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
929
|
+
nth = std::min(nth, (int) n);
|
|
930
|
+
|
|
931
|
+
const int nsg = (nth + 31) / 32;
|
|
932
|
+
|
|
933
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
934
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
935
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
936
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
937
|
+
|
|
938
|
+
ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
|
|
939
|
+
|
|
940
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
|
|
941
|
+
|
|
942
|
+
return 1;
|
|
943
|
+
}
|
|
944
|
+
|
|
830
945
|
int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
|
831
946
|
ggml_tensor * op = ctx->node(idx);
|
|
832
947
|
|
|
@@ -836,7 +951,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
|
|
836
951
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
837
952
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
838
953
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
839
|
-
GGML_TENSOR_LOCALS(
|
|
954
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
840
955
|
|
|
841
956
|
ggml_metal_kargs_sum_rows args = {
|
|
842
957
|
/*.ne00 =*/ ne00,
|
|
@@ -857,7 +972,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
|
|
857
972
|
/*.nb3 =*/ nb3,
|
|
858
973
|
};
|
|
859
974
|
|
|
860
|
-
|
|
975
|
+
auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
|
|
861
976
|
|
|
862
977
|
int nth = 32; // SIMD width
|
|
863
978
|
|
|
@@ -868,15 +983,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
|
|
868
983
|
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
869
984
|
nth = std::min(nth, ne00);
|
|
870
985
|
|
|
871
|
-
const size_t smem =
|
|
872
|
-
|
|
873
|
-
//[encoder setComputePipelineState:pipeline];
|
|
874
|
-
//[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
875
|
-
//[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
876
|
-
//[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
877
|
-
//[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
878
|
-
|
|
879
|
-
//[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
986
|
+
const size_t smem = pipeline.smem;
|
|
880
987
|
|
|
881
988
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
882
989
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -890,6 +997,149 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
|
|
890
997
|
return 1;
|
|
891
998
|
}
|
|
892
999
|
|
|
1000
|
+
int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
|
|
1001
|
+
ggml_tensor * op = ctx->node(idx);
|
|
1002
|
+
|
|
1003
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
1004
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
1005
|
+
|
|
1006
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
1007
|
+
|
|
1008
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1009
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1010
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1011
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1012
|
+
|
|
1013
|
+
auto pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
|
|
1014
|
+
|
|
1015
|
+
int nth = 1;
|
|
1016
|
+
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
|
|
1017
|
+
nth *= 2;
|
|
1018
|
+
}
|
|
1019
|
+
|
|
1020
|
+
GGML_ASSERT(ne00 <= nth*nth);
|
|
1021
|
+
|
|
1022
|
+
const int64_t net0 = (ne00 + nth - 1) / nth;
|
|
1023
|
+
const int64_t net1 = ne01;
|
|
1024
|
+
const int64_t net2 = ne02;
|
|
1025
|
+
const int64_t net3 = ne03;
|
|
1026
|
+
|
|
1027
|
+
const uint64_t nbt0 = sizeof(float);
|
|
1028
|
+
const uint64_t nbt1 = net0*nbt0;
|
|
1029
|
+
const uint64_t nbt2 = net1*nbt1;
|
|
1030
|
+
const uint64_t nbt3 = net2*nbt2;
|
|
1031
|
+
|
|
1032
|
+
const size_t smem = GGML_PAD(32*sizeof(float), 16);
|
|
1033
|
+
|
|
1034
|
+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
1035
|
+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
1036
|
+
|
|
1037
|
+
ggml_metal_buffer_id bid_tmp = bid_dst;
|
|
1038
|
+
bid_tmp.offs += ggml_nbytes(op);
|
|
1039
|
+
|
|
1040
|
+
{
|
|
1041
|
+
ggml_metal_kargs_cumsum_blk args = {
|
|
1042
|
+
/*.ne00 =*/ ne00,
|
|
1043
|
+
/*.ne01 =*/ ne01,
|
|
1044
|
+
/*.ne02 =*/ ne02,
|
|
1045
|
+
/*.ne03 =*/ ne03,
|
|
1046
|
+
/*.nb00 =*/ nb00,
|
|
1047
|
+
/*.nb01 =*/ nb01,
|
|
1048
|
+
/*.nb02 =*/ nb02,
|
|
1049
|
+
/*.nb03 =*/ nb03,
|
|
1050
|
+
/*.net0 =*/ net0,
|
|
1051
|
+
/*.net1 =*/ net1,
|
|
1052
|
+
/*.net2 =*/ net2,
|
|
1053
|
+
/*.net3 =*/ net3,
|
|
1054
|
+
/*.nbt0 =*/ nbt0,
|
|
1055
|
+
/*.nbt1 =*/ nbt1,
|
|
1056
|
+
/*.nbt2 =*/ nbt2,
|
|
1057
|
+
/*.nbt3 =*/ nbt3,
|
|
1058
|
+
/*.outb =*/ ne00 > nth,
|
|
1059
|
+
};
|
|
1060
|
+
|
|
1061
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
|
|
1062
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
1063
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
1064
|
+
ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
|
|
1065
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
|
|
1066
|
+
|
|
1067
|
+
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
1068
|
+
|
|
1069
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
|
|
1070
|
+
}
|
|
1071
|
+
|
|
1072
|
+
if (ne00 > nth) {
|
|
1073
|
+
ggml_metal_op_concurrency_reset(ctx);
|
|
1074
|
+
|
|
1075
|
+
{
|
|
1076
|
+
ggml_metal_kargs_cumsum_blk args = {
|
|
1077
|
+
/*.ne00 =*/ net0,
|
|
1078
|
+
/*.ne01 =*/ net1,
|
|
1079
|
+
/*.ne02 =*/ net2,
|
|
1080
|
+
/*.ne03 =*/ net3,
|
|
1081
|
+
/*.nb00 =*/ nbt0,
|
|
1082
|
+
/*.nb01 =*/ nbt1,
|
|
1083
|
+
/*.nb02 =*/ nbt2,
|
|
1084
|
+
/*.nb03 =*/ nbt3,
|
|
1085
|
+
/*.net0 =*/ net0,
|
|
1086
|
+
/*.net1 =*/ net1,
|
|
1087
|
+
/*.net2 =*/ net2,
|
|
1088
|
+
/*.net3 =*/ net3,
|
|
1089
|
+
/*.nbt0 =*/ nbt0,
|
|
1090
|
+
/*.nbt1 =*/ nbt1,
|
|
1091
|
+
/*.nbt2 =*/ nbt2,
|
|
1092
|
+
/*.nbt3 =*/ nbt3,
|
|
1093
|
+
/*.outb =*/ false,
|
|
1094
|
+
};
|
|
1095
|
+
|
|
1096
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
|
|
1097
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
1098
|
+
ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
|
|
1099
|
+
ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
|
|
1100
|
+
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
|
|
1101
|
+
|
|
1102
|
+
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
1103
|
+
|
|
1104
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1);
|
|
1105
|
+
}
|
|
1106
|
+
|
|
1107
|
+
ggml_metal_op_concurrency_reset(ctx);
|
|
1108
|
+
|
|
1109
|
+
{
|
|
1110
|
+
auto pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
|
|
1111
|
+
|
|
1112
|
+
ggml_metal_kargs_cumsum_add args = {
|
|
1113
|
+
/*.ne00 =*/ ne00,
|
|
1114
|
+
/*.ne01 =*/ ne01,
|
|
1115
|
+
/*.ne02 =*/ ne02,
|
|
1116
|
+
/*.ne03 =*/ ne03,
|
|
1117
|
+
/*.nb00 =*/ nb00,
|
|
1118
|
+
/*.nb01 =*/ nb01,
|
|
1119
|
+
/*.nb02 =*/ nb02,
|
|
1120
|
+
/*.nb03 =*/ nb03,
|
|
1121
|
+
/*.net0 =*/ net0,
|
|
1122
|
+
/*.net1 =*/ net1,
|
|
1123
|
+
/*.net2 =*/ net2,
|
|
1124
|
+
/*.net3 =*/ net3,
|
|
1125
|
+
/*.nbt0 =*/ nbt0,
|
|
1126
|
+
/*.nbt1 =*/ nbt1,
|
|
1127
|
+
/*.nbt2 =*/ nbt2,
|
|
1128
|
+
/*.nbt3 =*/ nbt3,
|
|
1129
|
+
};
|
|
1130
|
+
|
|
1131
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline_add);
|
|
1132
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
1133
|
+
ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
|
|
1134
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
1135
|
+
|
|
1136
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
|
|
1137
|
+
}
|
|
1138
|
+
}
|
|
1139
|
+
|
|
1140
|
+
return 1;
|
|
1141
|
+
}
|
|
1142
|
+
|
|
893
1143
|
int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
|
|
894
1144
|
ggml_tensor * op = ctx->node(idx);
|
|
895
1145
|
|
|
@@ -901,28 +1151,36 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
|
|
|
901
1151
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
902
1152
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
903
1153
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
904
|
-
GGML_TENSOR_LOCALS(
|
|
1154
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
905
1155
|
|
|
906
|
-
|
|
1156
|
+
auto pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
|
|
907
1157
|
|
|
908
1158
|
ggml_metal_kargs_get_rows args = {
|
|
909
|
-
/*.
|
|
910
|
-
/*.
|
|
911
|
-
/*.
|
|
912
|
-
/*.
|
|
913
|
-
/*.
|
|
914
|
-
/*.
|
|
915
|
-
/*.
|
|
916
|
-
/*.
|
|
1159
|
+
/*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
|
|
1160
|
+
/*.ne00 =*/ ne00,
|
|
1161
|
+
/*.nb01 =*/ nb01,
|
|
1162
|
+
/*.nb02 =*/ nb02,
|
|
1163
|
+
/*.nb03 =*/ nb03,
|
|
1164
|
+
/*.ne10 =*/ ne10,
|
|
1165
|
+
/*.nb10 =*/ nb10,
|
|
1166
|
+
/*.nb11 =*/ nb11,
|
|
1167
|
+
/*.nb12 =*/ nb12,
|
|
1168
|
+
/*.nb1 =*/ nb1,
|
|
1169
|
+
/*.nb2 =*/ nb2,
|
|
1170
|
+
/*.nb3 =*/ nb3,
|
|
917
1171
|
};
|
|
918
1172
|
|
|
1173
|
+
const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
1174
|
+
|
|
1175
|
+
const int nw0 = (args.ne00t + nth - 1)/nth;
|
|
1176
|
+
|
|
919
1177
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
920
1178
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
921
1179
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
922
1180
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
923
1181
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
|
924
1182
|
|
|
925
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12,
|
|
1183
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
|
|
926
1184
|
|
|
927
1185
|
return 1;
|
|
928
1186
|
}
|
|
@@ -938,9 +1196,9 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
|
|
|
938
1196
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
939
1197
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
940
1198
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
941
|
-
GGML_TENSOR_LOCALS(
|
|
1199
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
942
1200
|
|
|
943
|
-
|
|
1201
|
+
auto pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
|
|
944
1202
|
|
|
945
1203
|
const int32_t nk0 = ne0/ggml_blck_size(op->type);
|
|
946
1204
|
|
|
@@ -1002,7 +1260,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
|
|
1002
1260
|
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
|
1003
1261
|
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
|
1004
1262
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1005
|
-
GGML_TENSOR_LOCALS(
|
|
1263
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1006
1264
|
|
|
1007
1265
|
float scale;
|
|
1008
1266
|
float max_bias;
|
|
@@ -1041,7 +1299,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
|
|
1041
1299
|
/*.n_head_log2 =*/ n_head_log2,
|
|
1042
1300
|
};
|
|
1043
1301
|
|
|
1044
|
-
|
|
1302
|
+
auto pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op);
|
|
1045
1303
|
|
|
1046
1304
|
int nth = 32; // SIMD width
|
|
1047
1305
|
|
|
@@ -1055,7 +1313,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
|
|
1055
1313
|
}
|
|
1056
1314
|
}
|
|
1057
1315
|
|
|
1058
|
-
const size_t smem =
|
|
1316
|
+
const size_t smem = pipeline.smem;
|
|
1059
1317
|
|
|
1060
1318
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1061
1319
|
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
|
@@ -1090,7 +1348,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
|
|
|
1090
1348
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
1091
1349
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
1092
1350
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1093
|
-
GGML_TENSOR_LOCALS(
|
|
1351
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1094
1352
|
|
|
1095
1353
|
ggml_metal_kargs_ssm_conv args = {
|
|
1096
1354
|
/*.ne00 =*/ ne00,
|
|
@@ -1111,15 +1369,43 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
|
|
|
1111
1369
|
/*.nb2 =*/ nb2,
|
|
1112
1370
|
};
|
|
1113
1371
|
|
|
1114
|
-
|
|
1372
|
+
// Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead
|
|
1373
|
+
const bool use_batched = (ne1 > 1);
|
|
1115
1374
|
|
|
1116
|
-
|
|
1117
|
-
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
|
|
1375
|
+
if (use_batched) {
|
|
1376
|
+
// Determine the smallest power of 2 that's >= ne1, but <= 256
|
|
1377
|
+
int BATCH_SIZE;
|
|
1378
|
+
if (ne1 > 128) BATCH_SIZE = 256;
|
|
1379
|
+
else if (ne1 > 64 ) BATCH_SIZE = 128;
|
|
1380
|
+
else if (ne1 > 32 ) BATCH_SIZE = 64;
|
|
1381
|
+
else if (ne1 > 16 ) BATCH_SIZE = 32;
|
|
1382
|
+
else if (ne1 > 8 ) BATCH_SIZE = 16;
|
|
1383
|
+
else if (ne1 > 4 ) BATCH_SIZE = 8;
|
|
1384
|
+
else BATCH_SIZE = 2;
|
|
1385
|
+
|
|
1386
|
+
auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE);
|
|
1387
|
+
|
|
1388
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1389
|
+
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
|
1390
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
1391
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
1392
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
|
|
1393
|
+
|
|
1394
|
+
// Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences
|
|
1395
|
+
// Each threadgroup has BATCH_SIZE threads, each handling one token
|
|
1396
|
+
const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE;
|
|
1397
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1);
|
|
1398
|
+
} else {
|
|
1399
|
+
auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
|
|
1400
|
+
|
|
1401
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1402
|
+
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
|
1403
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
1404
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
1405
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
|
|
1121
1406
|
|
|
1122
|
-
|
|
1407
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
|
|
1408
|
+
}
|
|
1123
1409
|
|
|
1124
1410
|
return 1;
|
|
1125
1411
|
}
|
|
@@ -1145,7 +1431,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
|
|
1145
1431
|
GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
|
|
1146
1432
|
GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
|
|
1147
1433
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1148
|
-
GGML_TENSOR_LOCALS(
|
|
1434
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1149
1435
|
|
|
1150
1436
|
const ggml_tensor * src3 = op->src[3];
|
|
1151
1437
|
const ggml_tensor * src4 = op->src[4];
|
|
@@ -1172,26 +1458,37 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
|
|
1172
1458
|
/*.n_seq_tokens =*/ n_seq_tokens,
|
|
1173
1459
|
/*.n_seqs =*/ n_seqs,
|
|
1174
1460
|
/*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float),
|
|
1461
|
+
/*.nb00 =*/ nb00,
|
|
1175
1462
|
/*.nb01 =*/ nb01,
|
|
1176
1463
|
/*.nb02 =*/ nb02,
|
|
1177
1464
|
/*.nb03 =*/ nb03,
|
|
1465
|
+
/*.nb10 =*/ nb10,
|
|
1178
1466
|
/*.nb11 =*/ nb11,
|
|
1179
1467
|
/*.nb12 =*/ nb12,
|
|
1468
|
+
/*.ns12 =*/ nb12/nb10,
|
|
1180
1469
|
/*.nb13 =*/ nb13,
|
|
1470
|
+
/*.nb20 =*/ nb20,
|
|
1181
1471
|
/*.nb21 =*/ nb21,
|
|
1472
|
+
/*.ns21 =*/ nb21/nb20,
|
|
1182
1473
|
/*.nb22 =*/ nb22,
|
|
1474
|
+
/*.ne30 =*/ ne30,
|
|
1183
1475
|
/*.nb31 =*/ nb31,
|
|
1184
1476
|
/*.nb41 =*/ nb41,
|
|
1185
1477
|
/*.nb42 =*/ nb42,
|
|
1478
|
+
/*.ns42 =*/ nb42/nb40,
|
|
1186
1479
|
/*.nb43 =*/ nb43,
|
|
1187
1480
|
/*.nb51 =*/ nb51,
|
|
1188
1481
|
/*.nb52 =*/ nb52,
|
|
1482
|
+
/*.ns52 =*/ nb52/nb50,
|
|
1189
1483
|
/*.nb53 =*/ nb53,
|
|
1484
|
+
/*.nb0 =*/ nb0,
|
|
1190
1485
|
};
|
|
1191
1486
|
|
|
1192
|
-
|
|
1487
|
+
auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
|
|
1488
|
+
|
|
1489
|
+
GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
1193
1490
|
|
|
1194
|
-
const size_t
|
|
1491
|
+
const size_t smem = pipeline.smem;
|
|
1195
1492
|
|
|
1196
1493
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1197
1494
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -1204,15 +1501,9 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
|
|
1204
1501
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7);
|
|
1205
1502
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 8);
|
|
1206
1503
|
|
|
1207
|
-
ggml_metal_encoder_set_threadgroup_memory_size(enc,
|
|
1504
|
+
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
1208
1505
|
|
|
1209
|
-
|
|
1210
|
-
// Mamba-2
|
|
1211
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
|
|
1212
|
-
} else {
|
|
1213
|
-
GGML_ASSERT(d_inner == 1);
|
|
1214
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1);
|
|
1215
|
-
}
|
|
1506
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
|
|
1216
1507
|
|
|
1217
1508
|
return 1;
|
|
1218
1509
|
}
|
|
@@ -1226,14 +1517,14 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
|
|
|
1226
1517
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1227
1518
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1228
1519
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1229
|
-
GGML_TENSOR_LOCALS(
|
|
1520
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1230
1521
|
|
|
1231
1522
|
const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
|
|
1232
1523
|
const int64_t T = op->src[0]->ne[2];
|
|
1233
1524
|
const int64_t C = op->ne[0];
|
|
1234
1525
|
const int64_t H = op->src[0]->ne[1];
|
|
1235
1526
|
|
|
1236
|
-
|
|
1527
|
+
auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
|
|
1237
1528
|
|
|
1238
1529
|
int ida = 0;
|
|
1239
1530
|
|
|
@@ -1267,32 +1558,29 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
|
|
1267
1558
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1268
1559
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1269
1560
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1270
|
-
GGML_TENSOR_LOCALS(
|
|
1561
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1271
1562
|
|
|
1272
|
-
|
|
1563
|
+
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
|
1273
1564
|
|
|
1274
1565
|
GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
|
|
1275
1566
|
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
|
|
1282
|
-
while (nth < nk00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
1283
|
-
nth *= 2;
|
|
1567
|
+
int64_t nk0 = ne00;
|
|
1568
|
+
if (ggml_is_quantized(op->src[0]->type)) {
|
|
1569
|
+
nk0 = ne00/16;
|
|
1570
|
+
} else if (ggml_is_quantized(op->type)) {
|
|
1571
|
+
nk0 = ne00/ggml_blck_size(op->type);
|
|
1284
1572
|
}
|
|
1285
1573
|
|
|
1286
|
-
nth = std::min(
|
|
1574
|
+
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
1287
1575
|
|
|
1288
1576
|
// when rows are small, we can batch them together in a single threadgroup
|
|
1289
1577
|
int nrptg = 1;
|
|
1290
1578
|
|
|
1291
1579
|
// TODO: relax this constraint in the future
|
|
1292
1580
|
if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) {
|
|
1293
|
-
if (nth >
|
|
1294
|
-
nrptg = (nth +
|
|
1295
|
-
nth =
|
|
1581
|
+
if (nth > nk0) {
|
|
1582
|
+
nrptg = (nth + nk0 - 1)/nk0;
|
|
1583
|
+
nth = nk0;
|
|
1296
1584
|
|
|
1297
1585
|
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
1298
1586
|
nrptg--;
|
|
@@ -1300,10 +1588,11 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
|
|
1300
1588
|
}
|
|
1301
1589
|
}
|
|
1302
1590
|
|
|
1303
|
-
nth = std::min(nth,
|
|
1591
|
+
nth = std::min<int>(nth, nk0);
|
|
1304
1592
|
|
|
1305
1593
|
ggml_metal_kargs_cpy args = {
|
|
1306
|
-
/*.
|
|
1594
|
+
/*.nk0 =*/ nk0,
|
|
1595
|
+
/*.ne00 =*/ ne00,
|
|
1307
1596
|
/*.ne01 =*/ ne01,
|
|
1308
1597
|
/*.ne02 =*/ ne02,
|
|
1309
1598
|
/*.ne03 =*/ ne03,
|
|
@@ -1321,12 +1610,14 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
|
|
1321
1610
|
/*.nb3 =*/ nb3,
|
|
1322
1611
|
};
|
|
1323
1612
|
|
|
1613
|
+
const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
|
|
1614
|
+
|
|
1324
1615
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1325
1616
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
1326
1617
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
1327
1618
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
1328
1619
|
|
|
1329
|
-
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1);
|
|
1620
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
|
|
1330
1621
|
|
|
1331
1622
|
return 1;
|
|
1332
1623
|
}
|
|
@@ -1340,7 +1631,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
|
|
|
1340
1631
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
1341
1632
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
1342
1633
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1343
|
-
GGML_TENSOR_LOCALS(
|
|
1634
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1344
1635
|
|
|
1345
1636
|
const int32_t * opts = op->op_params;
|
|
1346
1637
|
ggml_op_pool op_pool = (ggml_op_pool) opts[0];
|
|
@@ -1376,7 +1667,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
|
|
|
1376
1667
|
/* .np = */ np
|
|
1377
1668
|
};
|
|
1378
1669
|
|
|
1379
|
-
|
|
1670
|
+
auto pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
|
|
1380
1671
|
|
|
1381
1672
|
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
|
|
1382
1673
|
const int ntg = (np + nth - 1) / nth;
|
|
@@ -1404,7 +1695,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|
|
1404
1695
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
1405
1696
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
1406
1697
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1407
|
-
GGML_TENSOR_LOCALS(
|
|
1698
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1408
1699
|
|
|
1409
1700
|
GGML_ASSERT(ne00 == ne10);
|
|
1410
1701
|
|
|
@@ -1485,7 +1776,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|
|
1485
1776
|
GGML_ABORT("unsupported ne11");
|
|
1486
1777
|
};
|
|
1487
1778
|
|
|
1488
|
-
|
|
1779
|
+
auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
|
|
1489
1780
|
|
|
1490
1781
|
ggml_metal_kargs_mul_mv_ext args = {
|
|
1491
1782
|
/*.ne00 =*/ ne00,
|
|
@@ -1520,9 +1811,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|
|
1520
1811
|
!ggml_is_transposed(op->src[1]) &&
|
|
1521
1812
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
|
1522
1813
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
|
1523
|
-
props_dev->has_simdgroup_mm && ne00 >= 64 &&
|
|
1524
|
-
(
|
|
1525
|
-
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
|
1814
|
+
props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
|
|
1815
|
+
//GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
|
1526
1816
|
|
|
1527
1817
|
// some Metal matrix data types require aligned pointers
|
|
1528
1818
|
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
|
@@ -1533,7 +1823,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|
|
1533
1823
|
// default: break;
|
|
1534
1824
|
//}
|
|
1535
1825
|
|
|
1536
|
-
|
|
1826
|
+
auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
|
|
1537
1827
|
|
|
1538
1828
|
ggml_metal_kargs_mul_mm args = {
|
|
1539
1829
|
/*.ne00 =*/ ne00,
|
|
@@ -1558,18 +1848,18 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|
|
1558
1848
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
1559
1849
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
|
1560
1850
|
|
|
1561
|
-
const size_t smem =
|
|
1851
|
+
const size_t smem = pipeline.smem;
|
|
1562
1852
|
|
|
1563
1853
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
1564
1854
|
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);
|
|
1565
1855
|
} else {
|
|
1566
|
-
|
|
1856
|
+
auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
|
|
1567
1857
|
|
|
1568
|
-
const int nr0 =
|
|
1569
|
-
const int nr1 =
|
|
1570
|
-
const int nsg =
|
|
1858
|
+
const int nr0 = pipeline.nr0;
|
|
1859
|
+
const int nr1 = pipeline.nr1;
|
|
1860
|
+
const int nsg = pipeline.nsg;
|
|
1571
1861
|
|
|
1572
|
-
const size_t smem =
|
|
1862
|
+
const size_t smem = pipeline.smem;
|
|
1573
1863
|
|
|
1574
1864
|
ggml_metal_kargs_mul_mv args = {
|
|
1575
1865
|
/*.ne00 =*/ ne00,
|
|
@@ -1646,7 +1936,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
|
|
1646
1936
|
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
|
1647
1937
|
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
|
1648
1938
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1649
|
-
GGML_TENSOR_LOCALS(
|
|
1939
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
1650
1940
|
|
|
1651
1941
|
// src2 = ids
|
|
1652
1942
|
GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);
|
|
@@ -1700,9 +1990,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
|
|
1700
1990
|
nb21,
|
|
1701
1991
|
};
|
|
1702
1992
|
|
|
1703
|
-
|
|
1993
|
+
auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
|
|
1704
1994
|
|
|
1705
|
-
const size_t smem =
|
|
1995
|
+
const size_t smem = pipeline.smem;
|
|
1706
1996
|
|
|
1707
1997
|
GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
1708
1998
|
|
|
@@ -1723,7 +2013,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
|
|
1723
2013
|
ggml_metal_op_concurrency_reset(ctx);
|
|
1724
2014
|
|
|
1725
2015
|
{
|
|
1726
|
-
|
|
2016
|
+
auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
|
|
1727
2017
|
|
|
1728
2018
|
ggml_metal_kargs_mul_mm_id args = {
|
|
1729
2019
|
/*.ne00 =*/ ne00,
|
|
@@ -1752,20 +2042,20 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
|
|
1752
2042
|
ggml_metal_encoder_set_buffer (enc, bid_ids, 4);
|
|
1753
2043
|
ggml_metal_encoder_set_buffer (enc, bid_dst, 5);
|
|
1754
2044
|
|
|
1755
|
-
const size_t smem =
|
|
2045
|
+
const size_t smem = pipeline.smem;
|
|
1756
2046
|
|
|
1757
2047
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
1758
2048
|
|
|
1759
2049
|
ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
|
|
1760
2050
|
}
|
|
1761
2051
|
} else {
|
|
1762
|
-
|
|
2052
|
+
auto pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
|
|
1763
2053
|
|
|
1764
|
-
const int nr0 =
|
|
1765
|
-
const int nr1 =
|
|
1766
|
-
const int nsg =
|
|
2054
|
+
const int nr0 = pipeline.nr0;
|
|
2055
|
+
const int nr1 = pipeline.nr1;
|
|
2056
|
+
const int nsg = pipeline.nsg;
|
|
1767
2057
|
|
|
1768
|
-
const size_t smem =
|
|
2058
|
+
const size_t smem = pipeline.smem;
|
|
1769
2059
|
|
|
1770
2060
|
ggml_metal_kargs_mul_mv_id args = {
|
|
1771
2061
|
/*.nei0 =*/ ne20,
|
|
@@ -1849,7 +2139,7 @@ int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) {
|
|
|
1849
2139
|
/*.nb21 =*/ nb21,
|
|
1850
2140
|
};
|
|
1851
2141
|
|
|
1852
|
-
|
|
2142
|
+
auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID);
|
|
1853
2143
|
|
|
1854
2144
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
1855
2145
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -1875,20 +2165,118 @@ bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {
|
|
|
1875
2165
|
return (ne01 < 20) && (ne00 % 32 == 0);
|
|
1876
2166
|
}
|
|
1877
2167
|
|
|
2168
|
+
size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
|
|
2169
|
+
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
|
2170
|
+
|
|
2171
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2172
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2173
|
+
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
2174
|
+
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
2175
|
+
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
|
2176
|
+
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
|
2177
|
+
GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
|
|
2178
|
+
GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
|
|
2179
|
+
|
|
2180
|
+
size_t res = 0;
|
|
2181
|
+
|
|
2182
|
+
const bool has_mask = op->src[3] != nullptr;
|
|
2183
|
+
|
|
2184
|
+
// note: the non-vec kernel requires more extra memory, so always reserve for it
|
|
2185
|
+
GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG);
|
|
2186
|
+
|
|
2187
|
+
//if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
|
2188
|
+
if (false) {
|
|
2189
|
+
// note: always reserve the padding space to avoid graph reallocations
|
|
2190
|
+
//const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
|
|
2191
|
+
const bool has_kvpad = true;
|
|
2192
|
+
|
|
2193
|
+
if (has_kvpad) {
|
|
2194
|
+
res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
|
|
2195
|
+
nb11*ne12*ne13 +
|
|
2196
|
+
nb21*ne22*ne23 +
|
|
2197
|
+
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
|
|
2198
|
+
}
|
|
2199
|
+
} else {
|
|
2200
|
+
//const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
|
|
2201
|
+
const bool has_kvpad = true;
|
|
2202
|
+
|
|
2203
|
+
if (has_kvpad) {
|
|
2204
|
+
res += OP_FLASH_ATTN_EXT_NCPSG*(
|
|
2205
|
+
nb11*ne12*ne13 +
|
|
2206
|
+
nb21*ne22*ne23 +
|
|
2207
|
+
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
|
|
2208
|
+
}
|
|
2209
|
+
}
|
|
2210
|
+
|
|
2211
|
+
return res;
|
|
2212
|
+
}
|
|
2213
|
+
|
|
2214
|
+
size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
|
|
2215
|
+
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
|
2216
|
+
|
|
2217
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2218
|
+
//GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2219
|
+
//GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
2220
|
+
//GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
2221
|
+
//GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
|
2222
|
+
//GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
|
2223
|
+
GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
|
|
2224
|
+
GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
|
|
2225
|
+
|
|
2226
|
+
size_t res = 0;
|
|
2227
|
+
|
|
2228
|
+
const bool has_mask = op->src[3] != nullptr;
|
|
2229
|
+
|
|
2230
|
+
if (!has_mask) {
|
|
2231
|
+
return res;
|
|
2232
|
+
}
|
|
2233
|
+
|
|
2234
|
+
const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
|
|
2235
|
+
|
|
2236
|
+
// this optimization is not useful for the vector kernels
|
|
2237
|
+
// note: always reserve the blk buffer to avoid graph reallocations
|
|
2238
|
+
//if (is_vec) {
|
|
2239
|
+
// return res;
|
|
2240
|
+
//}
|
|
2241
|
+
|
|
2242
|
+
const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
|
|
2243
|
+
const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
|
|
2244
|
+
|
|
2245
|
+
const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
|
|
2246
|
+
const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
|
|
2247
|
+
|
|
2248
|
+
res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
|
|
2249
|
+
|
|
2250
|
+
return res;
|
|
2251
|
+
}
|
|
2252
|
+
|
|
1878
2253
|
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
|
|
1879
2254
|
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
|
1880
2255
|
|
|
1881
|
-
|
|
2256
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2257
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2258
|
+
//GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
2259
|
+
//GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
2260
|
+
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
|
2261
|
+
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
|
2262
|
+
//GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
|
|
2263
|
+
//GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
|
|
2264
|
+
|
|
2265
|
+
size_t res = 0;
|
|
2266
|
+
|
|
2267
|
+
// note: always reserve the temp buffer to avoid graph reallocations
|
|
2268
|
+
//if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
|
2269
|
+
if (true) {
|
|
2270
|
+
const int64_t nwg = 32;
|
|
2271
|
+
const int64_t ne01_max = std::min(ne01, 32);
|
|
1882
2272
|
|
|
1883
|
-
|
|
1884
|
-
|
|
1885
|
-
|
|
1886
|
-
|
|
2273
|
+
// temp buffer for writing the results from each workgroup
|
|
2274
|
+
// - ne20: the size of the Value head
|
|
2275
|
+
// - + 2: the S and M values for each intermediate result
|
|
2276
|
+
res += ggml_type_size(GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));
|
|
2277
|
+
}
|
|
1887
2278
|
|
|
1888
|
-
|
|
1889
|
-
// - ne20: the size of the Value head
|
|
1890
|
-
// - + 2: the S and M values for each intermediate result
|
|
1891
|
-
return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
|
|
2279
|
+
return res;
|
|
1892
2280
|
}
|
|
1893
2281
|
|
|
1894
2282
|
int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
@@ -1910,8 +2298,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
1910
2298
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
1911
2299
|
GGML_TENSOR_LOCALS( int32_t, nb, op, nb);
|
|
1912
2300
|
|
|
1913
|
-
GGML_ASSERT(ne00 % 4
|
|
1914
|
-
GGML_ASSERT(ne11 % 32 == 0);
|
|
2301
|
+
GGML_ASSERT(ne00 % 4 == 0);
|
|
1915
2302
|
|
|
1916
2303
|
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
|
|
1917
2304
|
GGML_ASSERT(op->src[1]->type == op->src[2]->type);
|
|
@@ -1921,8 +2308,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
1921
2308
|
GGML_ASSERT(ne12 == ne22);
|
|
1922
2309
|
|
|
1923
2310
|
GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16);
|
|
1924
|
-
GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >=
|
|
1925
|
-
"the Flash-Attention Metal kernel requires the mask to be
|
|
2311
|
+
GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
|
|
2312
|
+
"the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
|
|
1926
2313
|
|
|
1927
2314
|
float scale;
|
|
1928
2315
|
float max_bias;
|
|
@@ -1949,15 +2336,107 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
1949
2336
|
|
|
1950
2337
|
GGML_ASSERT(ne01 < 65536);
|
|
1951
2338
|
|
|
1952
|
-
|
|
1953
|
-
|
|
1954
|
-
|
|
1955
|
-
|
|
2339
|
+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
2340
|
+
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
|
|
2341
|
+
ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
|
|
2342
|
+
ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
|
|
2343
|
+
ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
|
|
2344
|
+
|
|
2345
|
+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
2346
|
+
|
|
2347
|
+
ggml_metal_buffer_id bid_pad = bid_dst;
|
|
2348
|
+
bid_pad.offs += ggml_nbytes(op);
|
|
2349
|
+
|
|
2350
|
+
ggml_metal_buffer_id bid_blk = bid_pad;
|
|
2351
|
+
bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
|
|
2352
|
+
|
|
2353
|
+
ggml_metal_buffer_id bid_tmp = bid_blk;
|
|
2354
|
+
bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op);
|
|
2355
|
+
|
|
2356
|
+
if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
|
2357
|
+
// half8x8 kernel
|
|
2358
|
+
const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
|
|
2359
|
+
const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
|
|
1956
2360
|
|
|
1957
2361
|
GGML_ASSERT(nqptg <= 32);
|
|
1958
2362
|
GGML_ASSERT(nqptg % 8 == 0);
|
|
1959
2363
|
GGML_ASSERT(ncpsg % 32 == 0);
|
|
1960
2364
|
|
|
2365
|
+
bool need_sync = false;
|
|
2366
|
+
|
|
2367
|
+
const bool has_kvpad = ne11 % ncpsg != 0;
|
|
2368
|
+
|
|
2369
|
+
if (has_kvpad) {
|
|
2370
|
+
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
|
|
2371
|
+
|
|
2372
|
+
ggml_metal_kargs_flash_attn_ext_pad args0 = {
|
|
2373
|
+
/*.ne11 =*/ne11,
|
|
2374
|
+
/*.ne_12_2 =*/ne12,
|
|
2375
|
+
/*.ne_12_3 =*/ne13,
|
|
2376
|
+
/*.nb11 =*/nb11,
|
|
2377
|
+
/*.nb12 =*/nb12,
|
|
2378
|
+
/*.nb13 =*/nb13,
|
|
2379
|
+
/*.nb21 =*/nb21,
|
|
2380
|
+
/*.nb22 =*/nb22,
|
|
2381
|
+
/*.nb23 =*/nb23,
|
|
2382
|
+
/*.ne31 =*/ne31,
|
|
2383
|
+
/*.ne32 =*/ne32,
|
|
2384
|
+
/*.ne33 =*/ne33,
|
|
2385
|
+
/*.nb31 =*/nb31,
|
|
2386
|
+
/*.nb32 =*/nb32,
|
|
2387
|
+
/*.nb33 =*/nb33,
|
|
2388
|
+
};
|
|
2389
|
+
|
|
2390
|
+
auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
|
2391
|
+
|
|
2392
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
|
2393
|
+
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
|
2394
|
+
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
|
|
2395
|
+
ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
|
|
2396
|
+
ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
|
|
2397
|
+
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
|
|
2398
|
+
|
|
2399
|
+
assert(ne12 == ne22);
|
|
2400
|
+
assert(ne13 == ne23);
|
|
2401
|
+
|
|
2402
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
|
|
2403
|
+
|
|
2404
|
+
need_sync = true;
|
|
2405
|
+
}
|
|
2406
|
+
|
|
2407
|
+
if (has_mask) {
|
|
2408
|
+
assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
|
|
2409
|
+
|
|
2410
|
+
ggml_metal_kargs_flash_attn_ext_blk args0 = {
|
|
2411
|
+
/*.ne01 =*/ ne01,
|
|
2412
|
+
/*.ne30 =*/ ne30,
|
|
2413
|
+
/*.ne31 =*/ ne31,
|
|
2414
|
+
/*.ne32 =*/ ne32,
|
|
2415
|
+
/*.ne33 =*/ ne33,
|
|
2416
|
+
/*.nb31 =*/ nb31,
|
|
2417
|
+
/*.nb32 =*/ nb32,
|
|
2418
|
+
/*.nb33 =*/ nb33,
|
|
2419
|
+
};
|
|
2420
|
+
|
|
2421
|
+
auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
|
|
2422
|
+
|
|
2423
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
|
2424
|
+
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
|
2425
|
+
ggml_metal_encoder_set_buffer (enc, bid_src3, 1);
|
|
2426
|
+
ggml_metal_encoder_set_buffer (enc, bid_blk, 2);
|
|
2427
|
+
|
|
2428
|
+
const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
|
|
2429
|
+
const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
|
|
2430
|
+
|
|
2431
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
|
|
2432
|
+
|
|
2433
|
+
need_sync = true;
|
|
2434
|
+
}
|
|
2435
|
+
|
|
2436
|
+
if (need_sync) {
|
|
2437
|
+
ggml_metal_op_concurrency_reset(ctx);
|
|
2438
|
+
}
|
|
2439
|
+
|
|
1961
2440
|
const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
|
|
1962
2441
|
|
|
1963
2442
|
// 2*(2*ncpsg)
|
|
@@ -2007,6 +2486,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2007
2486
|
/*.nb21 =*/ nb21,
|
|
2008
2487
|
/*.nb22 =*/ nb22,
|
|
2009
2488
|
/*.nb23 =*/ nb23,
|
|
2489
|
+
/*.ne31 =*/ ne31,
|
|
2010
2490
|
/*.ne32 =*/ ne32,
|
|
2011
2491
|
/*.ne33 =*/ ne33,
|
|
2012
2492
|
/*.nb31 =*/ nb31,
|
|
@@ -2023,24 +2503,18 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2023
2503
|
/*.logit_softcap =*/ logit_softcap,
|
|
2024
2504
|
};
|
|
2025
2505
|
|
|
2026
|
-
|
|
2506
|
+
auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
|
|
2027
2507
|
|
|
2028
2508
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2029
2509
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
2030
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
2031
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
2032
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
2033
|
-
|
|
2034
|
-
|
|
2035
|
-
|
|
2036
|
-
|
|
2037
|
-
|
|
2038
|
-
if (op->src[4]) {
|
|
2039
|
-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
|
|
2040
|
-
} else {
|
|
2041
|
-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
|
|
2042
|
-
}
|
|
2043
|
-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 6);
|
|
2510
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
2511
|
+
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
|
|
2512
|
+
ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
|
|
2513
|
+
ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
|
|
2514
|
+
ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
|
|
2515
|
+
ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
|
|
2516
|
+
ggml_metal_encoder_set_buffer (enc, bid_blk, 7);
|
|
2517
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
|
|
2044
2518
|
|
|
2045
2519
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
2046
2520
|
|
|
@@ -2048,14 +2522,60 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2048
2522
|
#undef FATTN_SMEM
|
|
2049
2523
|
} else {
|
|
2050
2524
|
// half4x4 kernel
|
|
2051
|
-
const
|
|
2052
|
-
const
|
|
2053
|
-
const
|
|
2525
|
+
const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
|
|
2526
|
+
const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
|
|
2527
|
+
const int nkpsg = 1*ncpsg;
|
|
2054
2528
|
|
|
2055
2529
|
GGML_ASSERT(nqptg <= 32);
|
|
2056
2530
|
GGML_ASSERT(nqptg % 1 == 0);
|
|
2057
2531
|
GGML_ASSERT(ncpsg % 32 == 0);
|
|
2058
2532
|
|
|
2533
|
+
bool need_sync = false;
|
|
2534
|
+
|
|
2535
|
+
const bool has_kvpad = ne11 % ncpsg != 0;
|
|
2536
|
+
|
|
2537
|
+
if (has_kvpad) {
|
|
2538
|
+
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
|
|
2539
|
+
|
|
2540
|
+
ggml_metal_kargs_flash_attn_ext_pad args0 = {
|
|
2541
|
+
/*.ne11 =*/ne11,
|
|
2542
|
+
/*.ne_12_2 =*/ne12,
|
|
2543
|
+
/*.ne_12_3 =*/ne13,
|
|
2544
|
+
/*.nb11 =*/nb11,
|
|
2545
|
+
/*.nb12 =*/nb12,
|
|
2546
|
+
/*.nb13 =*/nb13,
|
|
2547
|
+
/*.nb21 =*/nb21,
|
|
2548
|
+
/*.nb22 =*/nb22,
|
|
2549
|
+
/*.nb23 =*/nb23,
|
|
2550
|
+
/*.ne31 =*/ne31,
|
|
2551
|
+
/*.ne32 =*/ne32,
|
|
2552
|
+
/*.ne33 =*/ne33,
|
|
2553
|
+
/*.nb31 =*/nb31,
|
|
2554
|
+
/*.nb32 =*/nb32,
|
|
2555
|
+
/*.nb33 =*/nb33,
|
|
2556
|
+
};
|
|
2557
|
+
|
|
2558
|
+
auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
|
2559
|
+
|
|
2560
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
|
2561
|
+
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
|
2562
|
+
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
|
|
2563
|
+
ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
|
|
2564
|
+
ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
|
|
2565
|
+
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
|
|
2566
|
+
|
|
2567
|
+
assert(ne12 == ne22);
|
|
2568
|
+
assert(ne13 == ne23);
|
|
2569
|
+
|
|
2570
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
|
|
2571
|
+
|
|
2572
|
+
need_sync = true;
|
|
2573
|
+
}
|
|
2574
|
+
|
|
2575
|
+
if (need_sync) {
|
|
2576
|
+
ggml_metal_op_concurrency_reset(ctx);
|
|
2577
|
+
}
|
|
2578
|
+
|
|
2059
2579
|
// ne00 + 2*ncpsg*(nsg)
|
|
2060
2580
|
// for each query, we load it as f16 in shared memory (ne00)
|
|
2061
2581
|
// and store the soft_max values and the mask
|
|
@@ -2120,6 +2640,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2120
2640
|
/*.nb21 =*/ nb21,
|
|
2121
2641
|
/*.nb22 =*/ nb22,
|
|
2122
2642
|
/*.nb23 =*/ nb23,
|
|
2643
|
+
/*.ne31 =*/ ne31,
|
|
2123
2644
|
/*.ne32 =*/ ne32,
|
|
2124
2645
|
/*.ne33 =*/ ne33,
|
|
2125
2646
|
/*.nb31 =*/ nb31,
|
|
@@ -2136,25 +2657,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2136
2657
|
/*.logit_softcap =*/ logit_softcap,
|
|
2137
2658
|
};
|
|
2138
2659
|
|
|
2139
|
-
|
|
2660
|
+
auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
|
|
2140
2661
|
|
|
2141
2662
|
GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
2142
2663
|
|
|
2143
2664
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2144
2665
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
2145
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
2146
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
2147
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
2148
|
-
|
|
2149
|
-
|
|
2150
|
-
} else {
|
|
2151
|
-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
|
|
2152
|
-
}
|
|
2153
|
-
if (op->src[4]) {
|
|
2154
|
-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
|
|
2155
|
-
} else {
|
|
2156
|
-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
|
|
2157
|
-
}
|
|
2666
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
2667
|
+
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
|
|
2668
|
+
ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
|
|
2669
|
+
ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
|
|
2670
|
+
ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
|
|
2158
2671
|
|
|
2159
2672
|
const size_t smem = FATTN_SMEM(nsg);
|
|
2160
2673
|
|
|
@@ -2162,23 +2675,25 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2162
2675
|
GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
|
|
2163
2676
|
|
|
2164
2677
|
if (nwg == 1) {
|
|
2678
|
+
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
|
|
2679
|
+
|
|
2165
2680
|
// using 1 workgroup -> write the result directly into dst
|
|
2166
|
-
ggml_metal_encoder_set_buffer(enc,
|
|
2681
|
+
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
|
|
2682
|
+
ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
|
|
2167
2683
|
|
|
2168
2684
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
2169
2685
|
|
|
2170
2686
|
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
|
|
2171
2687
|
} else {
|
|
2172
2688
|
// sanity checks
|
|
2689
|
+
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
|
|
2690
|
+
|
|
2173
2691
|
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
|
|
2174
2692
|
GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
|
|
2175
2693
|
|
|
2176
|
-
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
2177
|
-
|
|
2178
2694
|
// write the results from each workgroup into a temp buffer
|
|
2179
|
-
|
|
2180
|
-
bid_tmp
|
|
2181
|
-
ggml_metal_encoder_set_buffer(enc, bid_tmp, 6);
|
|
2695
|
+
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
|
|
2696
|
+
ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
|
|
2182
2697
|
|
|
2183
2698
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
2184
2699
|
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
|
|
@@ -2194,7 +2709,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|
|
2194
2709
|
nrows,
|
|
2195
2710
|
};
|
|
2196
2711
|
|
|
2197
|
-
|
|
2712
|
+
auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
|
|
2198
2713
|
|
|
2199
2714
|
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
|
2200
2715
|
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
|
@@ -2326,7 +2841,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|
|
2326
2841
|
// the offsets of src1 and all fused buffers are relative to the start of the src1 buffer
|
|
2327
2842
|
bid_src1.offs = 0;
|
|
2328
2843
|
|
|
2329
|
-
|
|
2844
|
+
struct ggml_metal_pipeline_with_params pipeline;
|
|
2330
2845
|
|
|
2331
2846
|
if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
|
2332
2847
|
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
@@ -2385,7 +2900,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
2385
2900
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2386
2901
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2387
2902
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2388
|
-
GGML_TENSOR_LOCALS(
|
|
2903
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
2389
2904
|
|
|
2390
2905
|
float eps;
|
|
2391
2906
|
memcpy(&eps, op->op_params, sizeof(float));
|
|
@@ -2399,7 +2914,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
2399
2914
|
/*.eps =*/ eps,
|
|
2400
2915
|
};
|
|
2401
2916
|
|
|
2402
|
-
|
|
2917
|
+
auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
|
|
2403
2918
|
|
|
2404
2919
|
while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
2405
2920
|
nth *= 2;
|
|
@@ -2408,7 +2923,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
2408
2923
|
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
2409
2924
|
nth = std::min(nth, ne00/4);
|
|
2410
2925
|
|
|
2411
|
-
const size_t smem =
|
|
2926
|
+
const size_t smem = pipeline.smem;
|
|
2412
2927
|
|
|
2413
2928
|
const int64_t nrows = ggml_nrows(op->src[0]);
|
|
2414
2929
|
|
|
@@ -2433,7 +2948,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
2433
2948
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2434
2949
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2435
2950
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2436
|
-
GGML_TENSOR_LOCALS(
|
|
2951
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
2437
2952
|
|
|
2438
2953
|
const int32_t ngrp = ((const int32_t *) op->op_params)[0];
|
|
2439
2954
|
|
|
@@ -2451,7 +2966,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
2451
2966
|
/*.eps =*/ eps,
|
|
2452
2967
|
};
|
|
2453
2968
|
|
|
2454
|
-
|
|
2969
|
+
auto pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op);
|
|
2455
2970
|
|
|
2456
2971
|
int nth = 32; // SIMD width
|
|
2457
2972
|
//while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
@@ -2461,7 +2976,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
2461
2976
|
//nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
2462
2977
|
//nth = std::min(nth, ne00/4);
|
|
2463
2978
|
|
|
2464
|
-
const size_t smem =
|
|
2979
|
+
const size_t smem = pipeline.smem;
|
|
2465
2980
|
|
|
2466
2981
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2467
2982
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -2488,7 +3003,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
2488
3003
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2489
3004
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2490
3005
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2491
|
-
GGML_TENSOR_LOCALS(
|
|
3006
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
2492
3007
|
|
|
2493
3008
|
float eps;
|
|
2494
3009
|
memcpy(&eps, op->op_params, sizeof(float));
|
|
@@ -2586,7 +3101,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
2586
3101
|
}
|
|
2587
3102
|
}
|
|
2588
3103
|
|
|
2589
|
-
|
|
3104
|
+
auto pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
|
|
2590
3105
|
|
|
2591
3106
|
int nth = 32; // SIMD width
|
|
2592
3107
|
|
|
@@ -2597,7 +3112,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
2597
3112
|
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
2598
3113
|
nth = std::min(nth, args.ne00_t);
|
|
2599
3114
|
|
|
2600
|
-
const size_t smem =
|
|
3115
|
+
const size_t smem = pipeline.smem;
|
|
2601
3116
|
|
|
2602
3117
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2603
3118
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -2624,7 +3139,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
|
|
|
2624
3139
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
2625
3140
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
2626
3141
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2627
|
-
GGML_TENSOR_LOCALS(
|
|
3142
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
2628
3143
|
|
|
2629
3144
|
// make sure we have one or more position id(ne10) per token(ne02)
|
|
2630
3145
|
GGML_ASSERT(ne10 % ne02 == 0);
|
|
@@ -2688,9 +3203,10 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
|
|
|
2688
3203
|
/* sect_1 =*/ sect_1,
|
|
2689
3204
|
/* sect_2 =*/ sect_2,
|
|
2690
3205
|
/* sect_3 =*/ sect_3,
|
|
3206
|
+
/* src2 =*/ op->src[2] != nullptr,
|
|
2691
3207
|
};
|
|
2692
3208
|
|
|
2693
|
-
|
|
3209
|
+
auto pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
|
|
2694
3210
|
|
|
2695
3211
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2696
3212
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -2717,7 +3233,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
|
|
|
2717
3233
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2718
3234
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2719
3235
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2720
|
-
GGML_TENSOR_LOCALS(
|
|
3236
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
2721
3237
|
|
|
2722
3238
|
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
|
|
2723
3239
|
const int32_t s1 = ((const int32_t *)(op->op_params))[1];
|
|
@@ -2762,7 +3278,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
|
|
|
2762
3278
|
/*.KHW =*/ KH * KW,
|
|
2763
3279
|
};
|
|
2764
3280
|
|
|
2765
|
-
|
|
3281
|
+
auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
|
|
2766
3282
|
|
|
2767
3283
|
GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
2768
3284
|
|
|
@@ -2778,6 +3294,84 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
|
|
|
2778
3294
|
return 1;
|
|
2779
3295
|
}
|
|
2780
3296
|
|
|
3297
|
+
int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
|
|
3298
|
+
ggml_tensor * op = ctx->node(idx);
|
|
3299
|
+
|
|
3300
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
3301
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
3302
|
+
|
|
3303
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
3304
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
3305
|
+
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
3306
|
+
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
3307
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3308
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3309
|
+
|
|
3310
|
+
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
3311
|
+
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
|
3312
|
+
GGML_ASSERT(op->type == GGML_TYPE_F32);
|
|
3313
|
+
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
|
|
3314
|
+
|
|
3315
|
+
const int32_t s0 = ((const int32_t *) op->op_params)[0];
|
|
3316
|
+
const int32_t s1 = ((const int32_t *) op->op_params)[1];
|
|
3317
|
+
const int32_t p0 = ((const int32_t *) op->op_params)[2];
|
|
3318
|
+
const int32_t p1 = ((const int32_t *) op->op_params)[3];
|
|
3319
|
+
const int32_t d0 = ((const int32_t *) op->op_params)[4];
|
|
3320
|
+
const int32_t d1 = ((const int32_t *) op->op_params)[5];
|
|
3321
|
+
|
|
3322
|
+
ggml_metal_kargs_conv_2d args = {
|
|
3323
|
+
/*.nb00 =*/ nb00,
|
|
3324
|
+
/*.nb01 =*/ nb01,
|
|
3325
|
+
/*.nb02 =*/ nb02,
|
|
3326
|
+
/*.nb03 =*/ nb03,
|
|
3327
|
+
/*.nb10 =*/ nb10,
|
|
3328
|
+
/*.nb11 =*/ nb11,
|
|
3329
|
+
/*.nb12 =*/ nb12,
|
|
3330
|
+
/*.nb13 =*/ nb13,
|
|
3331
|
+
/*.nb0 =*/ nb0,
|
|
3332
|
+
/*.nb1 =*/ nb1,
|
|
3333
|
+
/*.nb2 =*/ nb2,
|
|
3334
|
+
/*.nb3 =*/ nb3,
|
|
3335
|
+
/*.IW =*/ ne10,
|
|
3336
|
+
/*.IH =*/ ne11,
|
|
3337
|
+
/*.KW =*/ ne00,
|
|
3338
|
+
/*.KH =*/ ne01,
|
|
3339
|
+
/*.IC =*/ ne02,
|
|
3340
|
+
/*.OC =*/ ne03,
|
|
3341
|
+
/*.OW =*/ ne0,
|
|
3342
|
+
/*.OH =*/ ne1,
|
|
3343
|
+
/*.N =*/ ne3,
|
|
3344
|
+
/*.s0 =*/ s0,
|
|
3345
|
+
/*.s1 =*/ s1,
|
|
3346
|
+
/*.p0 =*/ p0,
|
|
3347
|
+
/*.p1 =*/ p1,
|
|
3348
|
+
/*.d0 =*/ d0,
|
|
3349
|
+
/*.d1 =*/ d1,
|
|
3350
|
+
};
|
|
3351
|
+
|
|
3352
|
+
auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);
|
|
3353
|
+
|
|
3354
|
+
int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
|
|
3355
|
+
nth = std::min(nth, 256);
|
|
3356
|
+
nth = std::max(nth, 1);
|
|
3357
|
+
|
|
3358
|
+
const uint64_t n_out = ggml_nelements(op);
|
|
3359
|
+
|
|
3360
|
+
uint64_t tg = (n_out + nth - 1)/nth;
|
|
3361
|
+
tg = std::max<uint64_t>(tg, 1);
|
|
3362
|
+
tg = std::min<uint64_t>(tg, (uint64_t) std::numeric_limits<int>::max());
|
|
3363
|
+
|
|
3364
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3365
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
3366
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
3367
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
3368
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
|
3369
|
+
|
|
3370
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, tg, 1, 1, nth, 1, 1);
|
|
3371
|
+
|
|
3372
|
+
return 1;
|
|
3373
|
+
}
|
|
3374
|
+
|
|
2781
3375
|
int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
|
|
2782
3376
|
ggml_tensor * op = ctx->node(idx);
|
|
2783
3377
|
|
|
@@ -2789,7 +3383,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
|
|
|
2789
3383
|
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
2790
3384
|
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
2791
3385
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2792
|
-
GGML_TENSOR_LOCALS(
|
|
3386
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
2793
3387
|
|
|
2794
3388
|
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
|
|
2795
3389
|
|
|
@@ -2810,7 +3404,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
|
|
|
2810
3404
|
/*.nb1 =*/ nb1,
|
|
2811
3405
|
};
|
|
2812
3406
|
|
|
2813
|
-
|
|
3407
|
+
auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
|
|
2814
3408
|
|
|
2815
3409
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
2816
3410
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -2823,6 +3417,62 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
|
|
|
2823
3417
|
return 1;
|
|
2824
3418
|
}
|
|
2825
3419
|
|
|
3420
|
+
int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
|
|
3421
|
+
ggml_tensor * op = ctx->node(idx);
|
|
3422
|
+
|
|
3423
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
3424
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
3425
|
+
|
|
3426
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
3427
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
3428
|
+
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
|
3429
|
+
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
3430
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3431
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3432
|
+
|
|
3433
|
+
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
|
|
3434
|
+
|
|
3435
|
+
const int32_t IC = op->src[1]->ne[2];
|
|
3436
|
+
const int32_t IH = op->src[1]->ne[1];
|
|
3437
|
+
const int32_t IW = op->src[1]->ne[0];
|
|
3438
|
+
|
|
3439
|
+
const int32_t KH = op->src[0]->ne[1];
|
|
3440
|
+
const int32_t KW = op->src[0]->ne[0];
|
|
3441
|
+
|
|
3442
|
+
const int32_t OW = op->ne[0];
|
|
3443
|
+
const int32_t OH = op->ne[1];
|
|
3444
|
+
const int32_t OC = op->ne[2];
|
|
3445
|
+
|
|
3446
|
+
ggml_metal_kargs_conv_transpose_2d args = {
|
|
3447
|
+
/*.IC =*/ IC,
|
|
3448
|
+
/*.IH =*/ IH,
|
|
3449
|
+
/*.IW =*/ IW,
|
|
3450
|
+
/*.KH =*/ KH,
|
|
3451
|
+
/*.KW =*/ KW,
|
|
3452
|
+
/*.OC =*/ OC,
|
|
3453
|
+
/*.s0 =*/ s0,
|
|
3454
|
+
/*.nb0 =*/ nb0,
|
|
3455
|
+
/*.nb1 =*/ nb1,
|
|
3456
|
+
/*.nb2 =*/ nb2,
|
|
3457
|
+
};
|
|
3458
|
+
|
|
3459
|
+
auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
|
|
3460
|
+
|
|
3461
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3462
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
3463
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
3464
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
3465
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
|
3466
|
+
|
|
3467
|
+
// Metal requires buffer size to be multiple of 16 bytes
|
|
3468
|
+
const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16);
|
|
3469
|
+
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
3470
|
+
|
|
3471
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
|
|
3472
|
+
|
|
3473
|
+
return 1;
|
|
3474
|
+
}
|
|
3475
|
+
|
|
2826
3476
|
int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
|
|
2827
3477
|
ggml_tensor * op = ctx->node(idx);
|
|
2828
3478
|
|
|
@@ -2832,7 +3482,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
|
|
|
2832
3482
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2833
3483
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2834
3484
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2835
|
-
GGML_TENSOR_LOCALS(
|
|
3485
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
2836
3486
|
|
|
2837
3487
|
const float sf0 = (float)ne0/op->src[0]->ne[0];
|
|
2838
3488
|
const float sf1 = (float)ne1/op->src[0]->ne[1];
|
|
@@ -2862,7 +3512,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
|
|
|
2862
3512
|
/*.sf3 =*/ sf3
|
|
2863
3513
|
};
|
|
2864
3514
|
|
|
2865
|
-
|
|
3515
|
+
auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
|
|
2866
3516
|
|
|
2867
3517
|
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
|
|
2868
3518
|
|
|
@@ -2885,7 +3535,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
|
|
|
2885
3535
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2886
3536
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2887
3537
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2888
|
-
GGML_TENSOR_LOCALS(
|
|
3538
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
2889
3539
|
|
|
2890
3540
|
ggml_metal_kargs_pad args = {
|
|
2891
3541
|
/*.ne00 =*/ ne00,
|
|
@@ -2906,7 +3556,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
|
|
|
2906
3556
|
/*.nb3 =*/ nb3
|
|
2907
3557
|
};
|
|
2908
3558
|
|
|
2909
|
-
|
|
3559
|
+
auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
|
|
2910
3560
|
|
|
2911
3561
|
const int nth = std::min(1024, ne0);
|
|
2912
3562
|
|
|
@@ -2929,7 +3579,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
|
|
|
2929
3579
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
2930
3580
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
2931
3581
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2932
|
-
GGML_TENSOR_LOCALS(
|
|
3582
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
2933
3583
|
|
|
2934
3584
|
ggml_metal_kargs_pad_reflect_1d args = {
|
|
2935
3585
|
/*.ne00 =*/ ne00,
|
|
@@ -2952,7 +3602,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
|
|
|
2952
3602
|
/*.p1 =*/ ((const int32_t *)(op->op_params))[1]
|
|
2953
3603
|
};
|
|
2954
3604
|
|
|
2955
|
-
|
|
3605
|
+
auto pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
|
|
2956
3606
|
|
|
2957
3607
|
const int nth = std::min(1024, ne0);
|
|
2958
3608
|
|
|
@@ -2973,7 +3623,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
|
|
|
2973
3623
|
ggml_metal_encoder_t enc = ctx->enc;
|
|
2974
3624
|
|
|
2975
3625
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
2976
|
-
GGML_TENSOR_LOCALS(
|
|
3626
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
2977
3627
|
|
|
2978
3628
|
float start;
|
|
2979
3629
|
float step;
|
|
@@ -2989,13 +3639,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
|
|
|
2989
3639
|
|
|
2990
3640
|
const int nth = std::min(1024, ne0);
|
|
2991
3641
|
|
|
2992
|
-
|
|
2993
|
-
|
|
2994
|
-
//[encoder setComputePipelineState:pipeline];
|
|
2995
|
-
//[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
|
|
2996
|
-
//[encoder setBytes:&args length:sizeof(args) atIndex:1];
|
|
2997
|
-
|
|
2998
|
-
//[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
3642
|
+
auto pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
|
|
2999
3643
|
|
|
3000
3644
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3001
3645
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -3015,7 +3659,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
|
|
|
3015
3659
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
3016
3660
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
3017
3661
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3018
|
-
GGML_TENSOR_LOCALS(
|
|
3662
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3019
3663
|
|
|
3020
3664
|
const int dim = op->op_params[0];
|
|
3021
3665
|
const int max_period = op->op_params[1];
|
|
@@ -3026,7 +3670,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
|
|
|
3026
3670
|
/*.max_period =*/ max_period,
|
|
3027
3671
|
};
|
|
3028
3672
|
|
|
3029
|
-
|
|
3673
|
+
auto pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
|
|
3030
3674
|
|
|
3031
3675
|
const int nth = std::max(1, std::min(1024, dim/2));
|
|
3032
3676
|
|
|
@@ -3049,14 +3693,14 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
|
|
|
3049
3693
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
3050
3694
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
3051
3695
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3052
|
-
GGML_TENSOR_LOCALS(
|
|
3696
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3053
3697
|
|
|
3054
3698
|
ggml_metal_kargs_argmax args = {
|
|
3055
3699
|
/*.ne00 = */ ne00,
|
|
3056
3700
|
/*.nb01 = */ nb01,
|
|
3057
3701
|
};
|
|
3058
3702
|
|
|
3059
|
-
|
|
3703
|
+
auto pipeline = ggml_metal_library_get_pipeline_argmax(lib, op);
|
|
3060
3704
|
|
|
3061
3705
|
const int64_t nrows = ggml_nrows(op->src[0]);
|
|
3062
3706
|
|
|
@@ -3065,7 +3709,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
|
|
|
3065
3709
|
nth *= 2;
|
|
3066
3710
|
}
|
|
3067
3711
|
|
|
3068
|
-
const size_t smem =
|
|
3712
|
+
const size_t smem = pipeline.smem;
|
|
3069
3713
|
|
|
3070
3714
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3071
3715
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
@@ -3085,38 +3729,215 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
|
|
3085
3729
|
ggml_metal_library_t lib = ctx->lib;
|
|
3086
3730
|
ggml_metal_encoder_t enc = ctx->enc;
|
|
3087
3731
|
|
|
3732
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
3733
|
+
|
|
3088
3734
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
3089
3735
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
3090
3736
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3091
|
-
GGML_TENSOR_LOCALS(
|
|
3737
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3738
|
+
|
|
3739
|
+
auto pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
|
|
3092
3740
|
|
|
3093
3741
|
// bitonic sort requires the number of elements to be power of 2
|
|
3094
|
-
|
|
3095
|
-
while (
|
|
3096
|
-
|
|
3742
|
+
int nth = 1;
|
|
3743
|
+
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
3744
|
+
nth *= 2;
|
|
3097
3745
|
}
|
|
3098
3746
|
|
|
3099
|
-
|
|
3100
|
-
|
|
3101
|
-
const int64_t nrows = ggml_nrows(op->src[0]);
|
|
3747
|
+
const int npr = (ne00 + nth - 1)/nth;
|
|
3102
3748
|
|
|
3103
3749
|
// Metal kernels require the buffer size to be multiple of 16 bytes
|
|
3104
3750
|
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
|
|
3105
|
-
const size_t smem = GGML_PAD(
|
|
3751
|
+
const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
|
|
3752
|
+
|
|
3753
|
+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
3754
|
+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
3755
|
+
|
|
3756
|
+
ggml_metal_buffer_id bid_tmp = bid_dst;
|
|
3757
|
+
bid_tmp.offs += ggml_nbytes(op);
|
|
3758
|
+
|
|
3759
|
+
if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
|
|
3760
|
+
std::swap(bid_dst, bid_tmp);
|
|
3761
|
+
}
|
|
3106
3762
|
|
|
3107
3763
|
ggml_metal_kargs_argsort args = {
|
|
3108
|
-
/*.
|
|
3109
|
-
/*.
|
|
3764
|
+
/*.ne00 =*/ ne00,
|
|
3765
|
+
/*.ne01 =*/ ne01,
|
|
3766
|
+
/*.ne02 =*/ ne02,
|
|
3767
|
+
/*.ne03 =*/ ne03,
|
|
3768
|
+
/*.nb00 =*/ nb00,
|
|
3769
|
+
/*.nb01 =*/ nb01,
|
|
3770
|
+
/*.nb02 =*/ nb02,
|
|
3771
|
+
/*.nb03 =*/ nb03,
|
|
3772
|
+
/*.ne0 =*/ ne0,
|
|
3773
|
+
/*.ne1 =*/ ne1,
|
|
3774
|
+
/*.ne2 =*/ ne2,
|
|
3775
|
+
/*.ne3 =*/ ne3,
|
|
3776
|
+
/*.top_k =*/ nth,
|
|
3110
3777
|
};
|
|
3111
3778
|
|
|
3112
3779
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3113
3780
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
3114
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
3115
|
-
ggml_metal_encoder_set_buffer (enc,
|
|
3781
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
3782
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
3116
3783
|
|
|
3117
3784
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
3118
3785
|
|
|
3119
|
-
ggml_metal_encoder_dispatch_threadgroups(enc,
|
|
3786
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
|
|
3787
|
+
|
|
3788
|
+
auto pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
|
|
3789
|
+
|
|
3790
|
+
int len = nth;
|
|
3791
|
+
|
|
3792
|
+
while (len < ne00) {
|
|
3793
|
+
ggml_metal_op_concurrency_reset(ctx);
|
|
3794
|
+
|
|
3795
|
+
ggml_metal_kargs_argsort_merge args_merge = {
|
|
3796
|
+
/*.ne00 =*/ ne00,
|
|
3797
|
+
/*.ne01 =*/ ne01,
|
|
3798
|
+
/*.ne02 =*/ ne02,
|
|
3799
|
+
/*.ne03 =*/ ne03,
|
|
3800
|
+
/*.nb00 =*/ nb00,
|
|
3801
|
+
/*.nb01 =*/ nb01,
|
|
3802
|
+
/*.nb02 =*/ nb02,
|
|
3803
|
+
/*.nb03 =*/ nb03,
|
|
3804
|
+
/*.ne0 =*/ ne0,
|
|
3805
|
+
/*.ne1 =*/ ne1,
|
|
3806
|
+
/*.ne2 =*/ ne2,
|
|
3807
|
+
/*.ne3 =*/ ne3,
|
|
3808
|
+
/*.top_k =*/ ne00,
|
|
3809
|
+
/*.len =*/ len,
|
|
3810
|
+
};
|
|
3811
|
+
|
|
3812
|
+
// merges per row
|
|
3813
|
+
const int nm = (ne00 + 2*len - 1) / (2*len);
|
|
3814
|
+
|
|
3815
|
+
const int nth = std::min(512, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));
|
|
3816
|
+
|
|
3817
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
|
|
3818
|
+
ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
|
|
3819
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
3820
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
3821
|
+
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
|
|
3822
|
+
|
|
3823
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
|
|
3824
|
+
|
|
3825
|
+
std::swap(bid_dst, bid_tmp);
|
|
3826
|
+
|
|
3827
|
+
len <<= 1;
|
|
3828
|
+
}
|
|
3829
|
+
|
|
3830
|
+
return 1;
|
|
3831
|
+
}
|
|
3832
|
+
|
|
3833
|
+
int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
|
|
3834
|
+
ggml_tensor * op = ctx->node(idx);
|
|
3835
|
+
|
|
3836
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
3837
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
3838
|
+
|
|
3839
|
+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
|
3840
|
+
|
|
3841
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
3842
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
3843
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3844
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3845
|
+
|
|
3846
|
+
auto pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
|
|
3847
|
+
|
|
3848
|
+
// bitonic sort requires the number of elements to be power of 2
|
|
3849
|
+
int nth = 1;
|
|
3850
|
+
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
3851
|
+
nth *= 2;
|
|
3852
|
+
}
|
|
3853
|
+
|
|
3854
|
+
// blocks per row
|
|
3855
|
+
const int npr = (ne00 + nth - 1)/nth;
|
|
3856
|
+
|
|
3857
|
+
const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
|
|
3858
|
+
|
|
3859
|
+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
3860
|
+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
3861
|
+
|
|
3862
|
+
ggml_metal_buffer_id bid_tmp = bid_dst;
|
|
3863
|
+
bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);
|
|
3864
|
+
|
|
3865
|
+
if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
|
|
3866
|
+
std::swap(bid_dst, bid_tmp);
|
|
3867
|
+
}
|
|
3868
|
+
|
|
3869
|
+
const int top_k = ne0;
|
|
3870
|
+
|
|
3871
|
+
ggml_metal_kargs_argsort args = {
|
|
3872
|
+
/*.ne00 =*/ ne00,
|
|
3873
|
+
/*.ne01 =*/ ne01,
|
|
3874
|
+
/*.ne02 =*/ ne02,
|
|
3875
|
+
/*.ne03 =*/ ne03,
|
|
3876
|
+
/*.nb00 =*/ nb00,
|
|
3877
|
+
/*.nb01 =*/ nb01,
|
|
3878
|
+
/*.nb02 =*/ nb02,
|
|
3879
|
+
/*.nb03 =*/ nb03,
|
|
3880
|
+
/*.ne0 =*/ ne0,
|
|
3881
|
+
/*.ne1 =*/ ne1,
|
|
3882
|
+
/*.ne2 =*/ ne2,
|
|
3883
|
+
/*.ne3 =*/ ne3,
|
|
3884
|
+
/*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices
|
|
3885
|
+
};
|
|
3886
|
+
|
|
3887
|
+
if (npr > 1) {
|
|
3888
|
+
args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
|
|
3889
|
+
}
|
|
3890
|
+
|
|
3891
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
3892
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
3893
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
3894
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
3895
|
+
|
|
3896
|
+
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
3897
|
+
|
|
3898
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
|
|
3899
|
+
|
|
3900
|
+
auto pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
|
|
3901
|
+
|
|
3902
|
+
int len = args.top_k;
|
|
3903
|
+
|
|
3904
|
+
while (len < args.ne0) {
|
|
3905
|
+
ggml_metal_op_concurrency_reset(ctx);
|
|
3906
|
+
|
|
3907
|
+
// merges per row
|
|
3908
|
+
const int nm = (args.ne0 + 2*len - 1) / (2*len);
|
|
3909
|
+
|
|
3910
|
+
const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
|
|
3911
|
+
|
|
3912
|
+
ggml_metal_kargs_argsort_merge args_merge = {
|
|
3913
|
+
/*.ne00 =*/ ne00,
|
|
3914
|
+
/*.ne01 =*/ ne01,
|
|
3915
|
+
/*.ne02 =*/ ne02,
|
|
3916
|
+
/*.ne03 =*/ ne03,
|
|
3917
|
+
/*.nb00 =*/ nb00,
|
|
3918
|
+
/*.nb01 =*/ nb01,
|
|
3919
|
+
/*.nb02 =*/ nb02,
|
|
3920
|
+
/*.nb03 =*/ nb03,
|
|
3921
|
+
/*.ne0 =*/ args.ne0,
|
|
3922
|
+
/*.ne1 =*/ ne1,
|
|
3923
|
+
/*.ne2 =*/ ne2,
|
|
3924
|
+
/*.ne3 =*/ ne3,
|
|
3925
|
+
/*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements
|
|
3926
|
+
/*.len =*/ len,
|
|
3927
|
+
};
|
|
3928
|
+
|
|
3929
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
|
|
3930
|
+
ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
|
|
3931
|
+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
3932
|
+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
|
3933
|
+
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
|
|
3934
|
+
|
|
3935
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
|
|
3936
|
+
|
|
3937
|
+
std::swap(bid_dst, bid_tmp);
|
|
3938
|
+
|
|
3939
|
+
len <<= 1;
|
|
3940
|
+
}
|
|
3120
3941
|
|
|
3121
3942
|
return 1;
|
|
3122
3943
|
}
|
|
@@ -3130,7 +3951,7 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
|
|
3130
3951
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
3131
3952
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
3132
3953
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3133
|
-
GGML_TENSOR_LOCALS(
|
|
3954
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3134
3955
|
|
|
3135
3956
|
float slope;
|
|
3136
3957
|
memcpy(&slope, op->op_params, sizeof(float));
|
|
@@ -3139,7 +3960,7 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
|
|
3139
3960
|
/*.slope =*/ slope
|
|
3140
3961
|
};
|
|
3141
3962
|
|
|
3142
|
-
|
|
3963
|
+
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
|
3143
3964
|
|
|
3144
3965
|
int64_t n = ggml_nelements(op);
|
|
3145
3966
|
|
|
@@ -3156,3 +3977,185 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
|
|
3156
3977
|
|
|
3157
3978
|
return 1;
|
|
3158
3979
|
}
|
|
3980
|
+
|
|
3981
|
+
int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
|
|
3982
|
+
ggml_tensor * op = ctx->node(idx);
|
|
3983
|
+
|
|
3984
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
3985
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
3986
|
+
|
|
3987
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
3988
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
3989
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
3990
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
3991
|
+
|
|
3992
|
+
ggml_metal_kargs_tri args = {
|
|
3993
|
+
/*.ne00 =*/ ne00,
|
|
3994
|
+
/*.ne01 =*/ ne01,
|
|
3995
|
+
/*.ne02 =*/ ne02,
|
|
3996
|
+
/*.ne03 =*/ ne03,
|
|
3997
|
+
/*.nb00 =*/ nb00,
|
|
3998
|
+
/*.nb01 =*/ nb01,
|
|
3999
|
+
/*.nb02 =*/ nb02,
|
|
4000
|
+
/*.nb03 =*/ nb03,
|
|
4001
|
+
/*.ne0 =*/ ne0,
|
|
4002
|
+
/*.ne1 =*/ ne1,
|
|
4003
|
+
/*.ne2 =*/ ne2,
|
|
4004
|
+
/*.ne3 =*/ ne3,
|
|
4005
|
+
/*.nb0 =*/ nb0,
|
|
4006
|
+
/*.nb1 =*/ nb1,
|
|
4007
|
+
/*.nb2 =*/ nb2,
|
|
4008
|
+
/*.nb3 =*/ nb3,
|
|
4009
|
+
};
|
|
4010
|
+
|
|
4011
|
+
auto pipeline = ggml_metal_library_get_pipeline_tri(lib, op);
|
|
4012
|
+
|
|
4013
|
+
int nth = 32; // SIMD width
|
|
4014
|
+
|
|
4015
|
+
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
4016
|
+
nth *= 2;
|
|
4017
|
+
}
|
|
4018
|
+
|
|
4019
|
+
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
4020
|
+
nth = std::min(nth, ne00);
|
|
4021
|
+
|
|
4022
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
4023
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
4024
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
4025
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
4026
|
+
|
|
4027
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
|
4028
|
+
|
|
4029
|
+
return 1;
|
|
4030
|
+
}
|
|
4031
|
+
|
|
4032
|
+
int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
|
|
4033
|
+
ggml_tensor * op = ctx->node(idx);
|
|
4034
|
+
|
|
4035
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
4036
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
4037
|
+
|
|
4038
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
4039
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
4040
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
4041
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
4042
|
+
|
|
4043
|
+
auto pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
|
|
4044
|
+
|
|
4045
|
+
const int64_t np = ggml_nelements(op->src[0]);
|
|
4046
|
+
ggml_metal_kargs_opt_step_adamw args = {
|
|
4047
|
+
/*.np =*/ np,
|
|
4048
|
+
};
|
|
4049
|
+
|
|
4050
|
+
int ida = 0;
|
|
4051
|
+
|
|
4052
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
4053
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
|
|
4054
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
|
|
4055
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
|
|
4056
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
|
|
4057
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
|
|
4058
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
|
|
4059
|
+
|
|
4060
|
+
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
|
|
4061
|
+
const int64_t n = (np + nth - 1) / nth;
|
|
4062
|
+
|
|
4063
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
|
|
4064
|
+
|
|
4065
|
+
return 1;
|
|
4066
|
+
}
|
|
4067
|
+
|
|
4068
|
+
int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
|
|
4069
|
+
ggml_tensor * op = ctx->node(idx);
|
|
4070
|
+
|
|
4071
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
4072
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
4073
|
+
|
|
4074
|
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
4075
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
4076
|
+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
4077
|
+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
|
4078
|
+
|
|
4079
|
+
auto pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
|
|
4080
|
+
|
|
4081
|
+
const int64_t np = ggml_nelements(op->src[0]);
|
|
4082
|
+
ggml_metal_kargs_opt_step_sgd args = {
|
|
4083
|
+
/*.np =*/ np,
|
|
4084
|
+
};
|
|
4085
|
+
|
|
4086
|
+
int ida = 0;
|
|
4087
|
+
|
|
4088
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
4089
|
+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
|
|
4090
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
|
|
4091
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
|
|
4092
|
+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
|
|
4093
|
+
|
|
4094
|
+
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
|
|
4095
|
+
const int64_t n = (np + nth - 1) / nth;
|
|
4096
|
+
|
|
4097
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
|
|
4098
|
+
|
|
4099
|
+
return 1;
|
|
4100
|
+
}
|
|
4101
|
+
|
|
4102
|
+
int ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) {
|
|
4103
|
+
ggml_tensor * op = ctx->node(idx);
|
|
4104
|
+
|
|
4105
|
+
ggml_metal_library_t lib = ctx->lib;
|
|
4106
|
+
ggml_metal_encoder_t enc = ctx->enc;
|
|
4107
|
+
|
|
4108
|
+
GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
|
|
4109
|
+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
4110
|
+
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
|
4111
|
+
|
|
4112
|
+
{
|
|
4113
|
+
ggml_metal_kargs_memset args = { /*.val =*/ 0 };
|
|
4114
|
+
|
|
4115
|
+
auto pipeline = ggml_metal_library_get_pipeline_memset(lib, op);
|
|
4116
|
+
|
|
4117
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
4118
|
+
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
|
4119
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 1);
|
|
4120
|
+
|
|
4121
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
|
|
4122
|
+
}
|
|
4123
|
+
|
|
4124
|
+
ggml_metal_op_concurrency_reset(ctx);
|
|
4125
|
+
|
|
4126
|
+
{
|
|
4127
|
+
ggml_metal_kargs_count_equal args = {
|
|
4128
|
+
/*.ne00 =*/ ne00,
|
|
4129
|
+
/*.ne01 =*/ ne01,
|
|
4130
|
+
/*.ne02 =*/ ne02,
|
|
4131
|
+
/*.ne03 =*/ ne03,
|
|
4132
|
+
/*.nb00 =*/ nb00,
|
|
4133
|
+
/*.nb01 =*/ nb01,
|
|
4134
|
+
/*.nb02 =*/ nb02,
|
|
4135
|
+
/*.nb03 =*/ nb03,
|
|
4136
|
+
/*.nb10 =*/ nb10,
|
|
4137
|
+
/*.nb11 =*/ nb11,
|
|
4138
|
+
/*.nb12 =*/ nb12,
|
|
4139
|
+
/*.nb13 =*/ nb13,
|
|
4140
|
+
};
|
|
4141
|
+
|
|
4142
|
+
auto pipeline = ggml_metal_library_get_pipeline_count_equal(lib, op);
|
|
4143
|
+
|
|
4144
|
+
const size_t smem = pipeline.smem;
|
|
4145
|
+
|
|
4146
|
+
const int nth = 32*pipeline.nsg;
|
|
4147
|
+
|
|
4148
|
+
GGML_ASSERT(nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
4149
|
+
|
|
4150
|
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
4151
|
+
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
|
4152
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
4153
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
|
4154
|
+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
|
|
4155
|
+
|
|
4156
|
+
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
4157
|
+
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
|
4158
|
+
}
|
|
4159
|
+
|
|
4160
|
+
return 1;
|
|
4161
|
+
}
|