whispercpp 1.3.4 → 1.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +47 -23
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/cli/cli.cpp +121 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +10 -11
- data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
- data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
- data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
- data/ext/sources/examples/talk-llama/llama-context.h +57 -9
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
- data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
- data/ext/sources/examples/talk-llama/llama-model.h +44 -3
- data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
- data/ext/sources/examples/talk-llama/llama.cpp +729 -2
- data/ext/sources/examples/talk-llama/llama.h +152 -14
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +82 -54
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +4 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +190 -12
- data/ext/sources/ggml/src/CMakeLists.txt +82 -11
- data/ext/sources/ggml/src/ggml-alloc.c +124 -41
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
- data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
- data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
- data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
- data/ext/sources/ggml/src/ggml-impl.h +67 -6
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +425 -33
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +101 -35
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +119 -2
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +7 -0
- data/whispercpp.gemspec +1 -1
- metadata +287 -34
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -9,6 +9,12 @@ __embed_ggml-common.h__
|
|
|
9
9
|
|
|
10
10
|
#include <metal_stdlib>
|
|
11
11
|
|
|
12
|
+
#ifdef GGML_METAL_HAS_TENSOR
|
|
13
|
+
#include <metal_tensor>
|
|
14
|
+
|
|
15
|
+
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
|
|
16
|
+
#endif
|
|
17
|
+
|
|
12
18
|
using namespace metal;
|
|
13
19
|
|
|
14
20
|
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
|
@@ -1243,6 +1249,22 @@ kernel void kernel_scale_f32_4(
|
|
|
1243
1249
|
dst[tpig] = src0[tpig] * args.scale + args.bias;
|
|
1244
1250
|
}
|
|
1245
1251
|
|
|
1252
|
+
kernel void kernel_fill_f32(
|
|
1253
|
+
constant ggml_metal_kargs_fill & args,
|
|
1254
|
+
device const float * src0,
|
|
1255
|
+
device float * dst,
|
|
1256
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
1257
|
+
dst[tpig] = args.val;
|
|
1258
|
+
}
|
|
1259
|
+
|
|
1260
|
+
kernel void kernel_fill_f32_4(
|
|
1261
|
+
constant ggml_metal_kargs_fill & args,
|
|
1262
|
+
device const float4 * src0,
|
|
1263
|
+
device float4 * dst,
|
|
1264
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
1265
|
+
dst[tpig] = args.val;
|
|
1266
|
+
}
|
|
1267
|
+
|
|
1246
1268
|
kernel void kernel_clamp_f32(
|
|
1247
1269
|
constant ggml_metal_kargs_clamp & args,
|
|
1248
1270
|
device const float * src0,
|
|
@@ -1589,6 +1611,36 @@ kernel void kernel_exp_f32_4(
|
|
|
1589
1611
|
dst[tpig] = exp(src0[tpig]);
|
|
1590
1612
|
}
|
|
1591
1613
|
|
|
1614
|
+
kernel void kernel_softplus_f32(
|
|
1615
|
+
device const float * src0,
|
|
1616
|
+
device float * dst,
|
|
1617
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
1618
|
+
device const float & x = src0[tpig];
|
|
1619
|
+
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
|
|
1620
|
+
}
|
|
1621
|
+
|
|
1622
|
+
kernel void kernel_softplus_f32_4(
|
|
1623
|
+
device const float4 * src0,
|
|
1624
|
+
device float4 * dst,
|
|
1625
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
1626
|
+
device const float4 & x = src0[tpig];
|
|
1627
|
+
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
|
|
1628
|
+
}
|
|
1629
|
+
|
|
1630
|
+
kernel void kernel_expm1_f32(
|
|
1631
|
+
device const float * src0,
|
|
1632
|
+
device float * dst,
|
|
1633
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
1634
|
+
dst[tpig] = exp(src0[tpig]) - 1.0f;
|
|
1635
|
+
}
|
|
1636
|
+
|
|
1637
|
+
kernel void kernel_expm1_f32_4(
|
|
1638
|
+
device const float4 * src0,
|
|
1639
|
+
device float4 * dst,
|
|
1640
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
1641
|
+
dst[tpig] = exp(src0[tpig]) - 1.0f;
|
|
1642
|
+
}
|
|
1643
|
+
|
|
1592
1644
|
kernel void kernel_reglu_f32(
|
|
1593
1645
|
constant ggml_metal_kargs_glu & args,
|
|
1594
1646
|
device const char * src0,
|
|
@@ -1723,6 +1775,55 @@ kernel void kernel_geglu_quick_f32(
|
|
|
1723
1775
|
}
|
|
1724
1776
|
}
|
|
1725
1777
|
|
|
1778
|
+
kernel void kernel_op_sum_f32(
|
|
1779
|
+
constant ggml_metal_kargs_sum & args,
|
|
1780
|
+
device const float * src0,
|
|
1781
|
+
device float * dst,
|
|
1782
|
+
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
|
1783
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1784
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
1785
|
+
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
|
1786
|
+
ushort tiisg[[thread_index_in_simdgroup]],
|
|
1787
|
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
1788
|
+
|
|
1789
|
+
if (args.np == 0) {
|
|
1790
|
+
return;
|
|
1791
|
+
}
|
|
1792
|
+
|
|
1793
|
+
// TODO: become function constant
|
|
1794
|
+
const uint nsg = (ntg.x + 31) / 32;
|
|
1795
|
+
|
|
1796
|
+
float sumf = 0;
|
|
1797
|
+
|
|
1798
|
+
for (uint64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
|
|
1799
|
+
sumf += src0[i0];
|
|
1800
|
+
}
|
|
1801
|
+
|
|
1802
|
+
sumf = simd_sum(sumf);
|
|
1803
|
+
|
|
1804
|
+
if (tiisg == 0) {
|
|
1805
|
+
shmem_f32[sgitg] = sumf;
|
|
1806
|
+
}
|
|
1807
|
+
|
|
1808
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1809
|
+
|
|
1810
|
+
float total = 0;
|
|
1811
|
+
|
|
1812
|
+
if (sgitg == 0) {
|
|
1813
|
+
float v = 0;
|
|
1814
|
+
|
|
1815
|
+
if (tpitg.x < nsg) {
|
|
1816
|
+
v = shmem_f32[tpitg.x];
|
|
1817
|
+
}
|
|
1818
|
+
|
|
1819
|
+
total = simd_sum(v);
|
|
1820
|
+
|
|
1821
|
+
if (tpitg.x == 0) {
|
|
1822
|
+
dst[0] = total;
|
|
1823
|
+
}
|
|
1824
|
+
}
|
|
1825
|
+
}
|
|
1826
|
+
|
|
1726
1827
|
template <bool norm>
|
|
1727
1828
|
kernel void kernel_sum_rows(
|
|
1728
1829
|
constant ggml_metal_kargs_sum_rows & args,
|
|
@@ -1778,6 +1879,186 @@ typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
|
|
|
1778
1879
|
template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
|
|
1779
1880
|
template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
|
|
1780
1881
|
|
|
1882
|
+
template<typename T>
|
|
1883
|
+
kernel void kernel_cumsum_blk(
|
|
1884
|
+
constant ggml_metal_kargs_cumsum_blk & args,
|
|
1885
|
+
device const char * src0,
|
|
1886
|
+
device char * tmp,
|
|
1887
|
+
device char * dst,
|
|
1888
|
+
threadgroup char * shmem [[threadgroup(0)]],
|
|
1889
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1890
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
1891
|
+
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
|
1892
|
+
ushort tiisg[[thread_index_in_simdgroup]],
|
|
1893
|
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
1894
|
+
const int ib = tgpig[0]/args.ne01;
|
|
1895
|
+
|
|
1896
|
+
const int i00 = ib*ntg.x;
|
|
1897
|
+
const int i01 = tgpig[0]%args.ne01;
|
|
1898
|
+
const int i02 = tgpig[1];
|
|
1899
|
+
const int i03 = tgpig[2];
|
|
1900
|
+
|
|
1901
|
+
device const float * src0_row = (device const float *) (src0 +
|
|
1902
|
+
args.nb01*i01 +
|
|
1903
|
+
args.nb02*i02 +
|
|
1904
|
+
args.nb03*i03);
|
|
1905
|
+
|
|
1906
|
+
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
|
1907
|
+
|
|
1908
|
+
float v = 0.0f;
|
|
1909
|
+
|
|
1910
|
+
if (i00 + tpitg.x < args.ne00) {
|
|
1911
|
+
v = src0_row[i00 + tpitg.x];
|
|
1912
|
+
}
|
|
1913
|
+
|
|
1914
|
+
float s = simd_prefix_inclusive_sum(v);
|
|
1915
|
+
|
|
1916
|
+
if (tiisg == N_SIMDWIDTH - 1) {
|
|
1917
|
+
shmem_f32[sgitg] = s;
|
|
1918
|
+
}
|
|
1919
|
+
|
|
1920
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1921
|
+
|
|
1922
|
+
if (sgitg == 0) {
|
|
1923
|
+
shmem_f32[tiisg] = simd_prefix_exclusive_sum(shmem_f32[tiisg]);
|
|
1924
|
+
}
|
|
1925
|
+
|
|
1926
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1927
|
+
|
|
1928
|
+
s += shmem_f32[sgitg];
|
|
1929
|
+
|
|
1930
|
+
device float * dst_row = (device float *) dst +
|
|
1931
|
+
args.ne00*i01 +
|
|
1932
|
+
args.ne00*args.ne01*i02 +
|
|
1933
|
+
args.ne00*args.ne01*args.ne02*i03;
|
|
1934
|
+
|
|
1935
|
+
if (i00 + tpitg.x < args.ne00) {
|
|
1936
|
+
dst_row[i00 + tpitg.x] = s;
|
|
1937
|
+
}
|
|
1938
|
+
|
|
1939
|
+
if (args.outb && tpitg.x == ntg.x - 1) {
|
|
1940
|
+
device float * tmp_row = (device float *) tmp +
|
|
1941
|
+
args.net0*i01 +
|
|
1942
|
+
args.net0*args.net1*i02 +
|
|
1943
|
+
args.net0*args.net1*args.net2*i03;
|
|
1944
|
+
|
|
1945
|
+
tmp_row[ib] = s;
|
|
1946
|
+
}
|
|
1947
|
+
}
|
|
1948
|
+
|
|
1949
|
+
typedef decltype(kernel_cumsum_blk<float>) kernel_cumsum_blk_t;
|
|
1950
|
+
|
|
1951
|
+
template [[host_name("kernel_cumsum_blk_f32")]] kernel kernel_cumsum_blk_t kernel_cumsum_blk<float>;
|
|
1952
|
+
|
|
1953
|
+
template<typename T>
|
|
1954
|
+
kernel void kernel_cumsum_add(
|
|
1955
|
+
constant ggml_metal_kargs_cumsum_add & args,
|
|
1956
|
+
device const char * tmp,
|
|
1957
|
+
device char * dst,
|
|
1958
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1959
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
1960
|
+
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
|
1961
|
+
ushort tiisg[[thread_index_in_simdgroup]],
|
|
1962
|
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
1963
|
+
const int ib = tgpig[0]/args.ne01;
|
|
1964
|
+
|
|
1965
|
+
if (ib == 0) {
|
|
1966
|
+
return;
|
|
1967
|
+
}
|
|
1968
|
+
|
|
1969
|
+
const int i00 = ib*ntg.x;
|
|
1970
|
+
const int i01 = tgpig[0]%args.ne01;
|
|
1971
|
+
const int i02 = tgpig[1];
|
|
1972
|
+
const int i03 = tgpig[2];
|
|
1973
|
+
|
|
1974
|
+
device const float * tmp_row = (device const float *) (tmp +
|
|
1975
|
+
args.nbt1*i01 +
|
|
1976
|
+
args.nbt2*i02 +
|
|
1977
|
+
args.nbt3*i03);
|
|
1978
|
+
|
|
1979
|
+
device float * dst_row = (device float *) dst +
|
|
1980
|
+
args.ne00*i01 +
|
|
1981
|
+
args.ne00*args.ne01*i02 +
|
|
1982
|
+
args.ne00*args.ne01*args.ne02*i03;
|
|
1983
|
+
|
|
1984
|
+
if (i00 + tpitg.x < args.ne00) {
|
|
1985
|
+
dst_row[i00 + tpitg.x] += tmp_row[ib - 1];
|
|
1986
|
+
}
|
|
1987
|
+
}
|
|
1988
|
+
|
|
1989
|
+
typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
|
|
1990
|
+
|
|
1991
|
+
template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
|
|
1992
|
+
|
|
1993
|
+
|
|
1994
|
+
template<uint32_t ttype>
|
|
1995
|
+
bool _ggml_vec_tri_cmp(const int i, const int r);
|
|
1996
|
+
|
|
1997
|
+
template<>
|
|
1998
|
+
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER */ 3>(const int i, const int r) {
|
|
1999
|
+
return i < r;
|
|
2000
|
+
}
|
|
2001
|
+
|
|
2002
|
+
template<>
|
|
2003
|
+
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER_DIAG */ 2>(const int i, const int r) {
|
|
2004
|
+
return i <= r;
|
|
2005
|
+
}
|
|
2006
|
+
|
|
2007
|
+
template<>
|
|
2008
|
+
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER */ 1>(const int i, const int r) {
|
|
2009
|
+
return i > r;
|
|
2010
|
+
}
|
|
2011
|
+
|
|
2012
|
+
template<>
|
|
2013
|
+
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER_DIAG */ 0>(const int i, const int r) {
|
|
2014
|
+
return i >= r;
|
|
2015
|
+
}
|
|
2016
|
+
|
|
2017
|
+
template<typename T, int ttype>
|
|
2018
|
+
kernel void kernel_tri(
|
|
2019
|
+
constant ggml_metal_kargs_tri & args,
|
|
2020
|
+
device const char * src0,
|
|
2021
|
+
device const char * dst,
|
|
2022
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2023
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
2024
|
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
2025
|
+
const int i3 = tgpig.z;
|
|
2026
|
+
const int i2 = tgpig.y;
|
|
2027
|
+
const int i1 = tgpig.x;
|
|
2028
|
+
|
|
2029
|
+
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
|
2030
|
+
return;
|
|
2031
|
+
}
|
|
2032
|
+
|
|
2033
|
+
device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
|
2034
|
+
device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
|
2035
|
+
|
|
2036
|
+
// Each thread is a single element of the row if ne00 < max threads per
|
|
2037
|
+
// threadgroup, so this will loop once for each index that this thread is
|
|
2038
|
+
// responsible for
|
|
2039
|
+
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
|
2040
|
+
// Use the comparison as a mask for branchless
|
|
2041
|
+
dst_row[i0] = static_cast<T>(_ggml_vec_tri_cmp<ttype>(i0, i1)) * src_row[i0];
|
|
2042
|
+
}
|
|
2043
|
+
}
|
|
2044
|
+
|
|
2045
|
+
typedef decltype(kernel_tri<float, 0>) kernel_tri_t;
|
|
2046
|
+
|
|
2047
|
+
template [[host_name("kernel_tri_f32_0")]] kernel kernel_tri_t kernel_tri<float, 0>;
|
|
2048
|
+
template [[host_name("kernel_tri_f32_1")]] kernel kernel_tri_t kernel_tri<float, 1>;
|
|
2049
|
+
template [[host_name("kernel_tri_f32_2")]] kernel kernel_tri_t kernel_tri<float, 2>;
|
|
2050
|
+
template [[host_name("kernel_tri_f32_3")]] kernel kernel_tri_t kernel_tri<float, 3>;
|
|
2051
|
+
template [[host_name("kernel_tri_f16_0")]] kernel kernel_tri_t kernel_tri<half, 0>;
|
|
2052
|
+
template [[host_name("kernel_tri_f16_1")]] kernel kernel_tri_t kernel_tri<half, 1>;
|
|
2053
|
+
template [[host_name("kernel_tri_f16_2")]] kernel kernel_tri_t kernel_tri<half, 2>;
|
|
2054
|
+
template [[host_name("kernel_tri_f16_3")]] kernel kernel_tri_t kernel_tri<half, 3>;
|
|
2055
|
+
#if defined(GGML_METAL_HAS_BF16)
|
|
2056
|
+
template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri<bfloat, 0>;
|
|
2057
|
+
template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri<bfloat, 1>;
|
|
2058
|
+
template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri<bfloat, 2>;
|
|
2059
|
+
template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri<bfloat, 3>;
|
|
2060
|
+
#endif
|
|
2061
|
+
|
|
1781
2062
|
template<typename T>
|
|
1782
2063
|
kernel void kernel_soft_max(
|
|
1783
2064
|
constant ggml_metal_kargs_soft_max & args,
|
|
@@ -2032,124 +2313,134 @@ kernel void kernel_ssm_conv_f32_f32(
|
|
|
2032
2313
|
x[0] = sumf;
|
|
2033
2314
|
}
|
|
2034
2315
|
|
|
2035
|
-
|
|
2036
|
-
|
|
2037
|
-
|
|
2038
|
-
device const
|
|
2039
|
-
device
|
|
2040
|
-
|
|
2041
|
-
|
|
2042
|
-
|
|
2043
|
-
|
|
2044
|
-
|
|
2045
|
-
|
|
2046
|
-
|
|
2047
|
-
|
|
2048
|
-
|
|
2049
|
-
|
|
2050
|
-
|
|
2051
|
-
|
|
2052
|
-
uint3 tgpg[[threadgroups_per_grid]]) {
|
|
2316
|
+
kernel void kernel_ssm_conv_f32_f32_4(
|
|
2317
|
+
constant ggml_metal_kargs_ssm_conv & args,
|
|
2318
|
+
device const void * src0,
|
|
2319
|
+
device const void * src1,
|
|
2320
|
+
device float * dst,
|
|
2321
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2322
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
2323
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
2324
|
+
const int64_t ir = tgpig.x;
|
|
2325
|
+
const int64_t i2 = tgpig.y;
|
|
2326
|
+
const int64_t i3 = tgpig.z;
|
|
2327
|
+
|
|
2328
|
+
const int64_t nc = args.ne10;
|
|
2329
|
+
//const int64_t ncs = args.ne00;
|
|
2330
|
+
//const int64_t nr = args.ne01;
|
|
2331
|
+
//const int64_t n_t = args.ne1;
|
|
2332
|
+
//const int64_t n_s = args.ne2;
|
|
2053
2333
|
|
|
2054
|
-
const
|
|
2055
|
-
const
|
|
2056
|
-
|
|
2057
|
-
const int64_t i3 = tgpig.y; // current seq
|
|
2334
|
+
device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
|
|
2335
|
+
device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
|
|
2336
|
+
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
|
|
2058
2337
|
|
|
2059
|
-
|
|
2060
|
-
const uint64_t nb10 = sizeof(float);
|
|
2061
|
-
const uint64_t nb20 = sizeof(float);
|
|
2338
|
+
float sumf = 0.0f;
|
|
2062
2339
|
|
|
2063
|
-
|
|
2064
|
-
|
|
2065
|
-
|
|
2066
|
-
const int64_t ng = args.n_group;
|
|
2067
|
-
const int64_t n_t = args.n_seq_tokens;
|
|
2340
|
+
for (int64_t i0 = 0; i0 < nc/4; ++i0) {
|
|
2341
|
+
sumf += dot(s[i0], c[i0]);
|
|
2342
|
+
}
|
|
2068
2343
|
|
|
2069
|
-
|
|
2344
|
+
x[0] = sumf;
|
|
2345
|
+
}
|
|
2070
2346
|
|
|
2071
|
-
|
|
2347
|
+
constant short FC_ssm_conv_bs [[function_constant(FC_SSM_CONV + 0)]];
|
|
2072
2348
|
|
|
2073
|
-
|
|
2074
|
-
|
|
2075
|
-
|
|
2076
|
-
|
|
2077
|
-
|
|
2078
|
-
|
|
2079
|
-
|
|
2080
|
-
|
|
2081
|
-
|
|
2082
|
-
|
|
2083
|
-
|
|
2084
|
-
|
|
2085
|
-
|
|
2086
|
-
|
|
2087
|
-
|
|
2088
|
-
|
|
2089
|
-
|
|
2090
|
-
|
|
2091
|
-
|
|
2092
|
-
|
|
2093
|
-
|
|
2094
|
-
|
|
2095
|
-
|
|
2096
|
-
|
|
2097
|
-
|
|
2098
|
-
|
|
2099
|
-
|
|
2100
|
-
|
|
2101
|
-
|
|
2102
|
-
// are subdivided into SIMD groups of size `sgptg`. The goal is to
|
|
2103
|
-
// compute y = sum({state * C[i] for i in range(d_state)}).
|
|
2104
|
-
// To parallelize this effectively, we first use simd_sum over each SIMD
|
|
2105
|
-
// group to compute the sum of each SIMD group, then place the result in
|
|
2106
|
-
// the SIMD group's indexed bucket in the shared memory. We then sum
|
|
2107
|
-
// over the individual group sums to compute the final sum.
|
|
2108
|
-
|
|
2109
|
-
// Computed for each thread
|
|
2110
|
-
float sumf = state * C[i0];
|
|
2111
|
-
|
|
2112
|
-
// Sum the threads in the simd group => simd sum
|
|
2113
|
-
sumf = simd_sum(sumf);
|
|
2114
|
-
|
|
2115
|
-
if (sgptg > 1) {
|
|
2116
|
-
|
|
2117
|
-
// Once per simd group, place the group sum into the shared buffer
|
|
2118
|
-
if (tiisg == 0) {
|
|
2119
|
-
shared[sgitg] = sumf;
|
|
2120
|
-
}
|
|
2349
|
+
// Batched version: each threadgroup processes multiple tokens for better efficiency
|
|
2350
|
+
// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens
|
|
2351
|
+
kernel void kernel_ssm_conv_f32_f32_batched(
|
|
2352
|
+
constant ggml_metal_kargs_ssm_conv & args,
|
|
2353
|
+
device const void * src0,
|
|
2354
|
+
device const void * src1,
|
|
2355
|
+
device float * dst,
|
|
2356
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2357
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
2358
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
2359
|
+
// tgpig.x = row index (ir)
|
|
2360
|
+
// tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
|
|
2361
|
+
// tgpig.z = sequence index (i3)
|
|
2362
|
+
// tpitg.x = thread within batch (0..BATCH_SIZE-1)
|
|
2363
|
+
const short BATCH_SIZE = FC_ssm_conv_bs;
|
|
2364
|
+
|
|
2365
|
+
const int64_t ir = tgpig.x;
|
|
2366
|
+
const int64_t i2_base = tgpig.y * BATCH_SIZE;
|
|
2367
|
+
const int64_t i3 = tgpig.z;
|
|
2368
|
+
const int64_t i2_off = tpitg.x;
|
|
2369
|
+
const int64_t i2 = i2_base + i2_off;
|
|
2370
|
+
|
|
2371
|
+
const int64_t nc = args.ne10; // conv kernel size (typically 4)
|
|
2372
|
+
const int64_t n_t = args.ne1; // number of tokens
|
|
2373
|
+
|
|
2374
|
+
// Bounds check for partial batches at the end
|
|
2375
|
+
if (i2 >= n_t) {
|
|
2376
|
+
return;
|
|
2377
|
+
}
|
|
2121
2378
|
|
|
2122
|
-
|
|
2123
|
-
|
|
2124
|
-
// sum of the individual simd groups.
|
|
2125
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
2379
|
+
// Load conv weights (shared across all tokens for this row)
|
|
2380
|
+
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
|
|
2126
2381
|
|
|
2127
|
-
|
|
2128
|
-
|
|
2129
|
-
|
|
2130
|
-
|
|
2131
|
-
|
|
2132
|
-
sumf = shared[tiisg];
|
|
2133
|
-
}
|
|
2134
|
-
sumf = simd_sum(sumf);
|
|
2135
|
-
if (tiisg == 0) {
|
|
2136
|
-
y[0] = sumf;
|
|
2137
|
-
}
|
|
2138
|
-
}
|
|
2139
|
-
} else if (tiisg == 0) {
|
|
2140
|
-
y[0] = sumf;
|
|
2141
|
-
}
|
|
2382
|
+
// Load source for this specific token
|
|
2383
|
+
device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
|
|
2384
|
+
|
|
2385
|
+
// Output location for this token
|
|
2386
|
+
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
|
|
2142
2387
|
|
|
2143
|
-
|
|
2144
|
-
|
|
2388
|
+
float sumf = 0.0f;
|
|
2389
|
+
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
|
2390
|
+
sumf += s[i0] * c[i0];
|
|
2145
2391
|
}
|
|
2146
2392
|
|
|
2147
|
-
|
|
2148
|
-
|
|
2393
|
+
x[0] = sumf;
|
|
2394
|
+
}
|
|
2395
|
+
|
|
2396
|
+
kernel void kernel_ssm_conv_f32_f32_batched_4(
|
|
2397
|
+
constant ggml_metal_kargs_ssm_conv & args,
|
|
2398
|
+
device const void * src0,
|
|
2399
|
+
device const void * src1,
|
|
2400
|
+
device float * dst,
|
|
2401
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2402
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
2403
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
2404
|
+
// tgpig.x = row index (ir)
|
|
2405
|
+
// tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
|
|
2406
|
+
// tgpig.z = sequence index (i3)
|
|
2407
|
+
// tpitg.x = thread within batch (0..BATCH_SIZE-1)
|
|
2408
|
+
const short BATCH_SIZE = FC_ssm_conv_bs;
|
|
2409
|
+
|
|
2410
|
+
const int64_t ir = tgpig.x;
|
|
2411
|
+
const int64_t i2_base = tgpig.y * BATCH_SIZE;
|
|
2412
|
+
const int64_t i3 = tgpig.z;
|
|
2413
|
+
const int64_t i2_off = tpitg.x;
|
|
2414
|
+
const int64_t i2 = i2_base + i2_off;
|
|
2415
|
+
|
|
2416
|
+
const int64_t nc = args.ne10; // conv kernel size (typically 4)
|
|
2417
|
+
const int64_t n_t = args.ne1; // number of tokens
|
|
2418
|
+
|
|
2419
|
+
// Bounds check for partial batches at the end
|
|
2420
|
+
if (i2 >= n_t) {
|
|
2421
|
+
return;
|
|
2422
|
+
}
|
|
2423
|
+
|
|
2424
|
+
// Load conv weights (shared across all tokens for this row)
|
|
2425
|
+
device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
|
|
2426
|
+
|
|
2427
|
+
// Load source for this specific token
|
|
2428
|
+
device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
|
|
2429
|
+
|
|
2430
|
+
// Output location for this token
|
|
2431
|
+
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
|
|
2432
|
+
|
|
2433
|
+
float sumf = 0.0f;
|
|
2434
|
+
for (int64_t i0 = 0; i0 < nc/4; ++i0) {
|
|
2435
|
+
sumf += dot(s[i0], c[i0]);
|
|
2436
|
+
}
|
|
2437
|
+
|
|
2438
|
+
x[0] = sumf;
|
|
2149
2439
|
}
|
|
2150
2440
|
|
|
2151
2441
|
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
|
|
2152
|
-
|
|
2442
|
+
// Optimized version: reduces redundant memory loads by having one thread load shared values
|
|
2443
|
+
kernel void kernel_ssm_scan_f32(
|
|
2153
2444
|
constant ggml_metal_kargs_ssm_scan & args,
|
|
2154
2445
|
device const void * src0,
|
|
2155
2446
|
device const void * src1,
|
|
@@ -2160,103 +2451,111 @@ kernel void kernel_ssm_scan_group_f32(
|
|
|
2160
2451
|
device const void * src6,
|
|
2161
2452
|
device float * dst,
|
|
2162
2453
|
threadgroup float * shared [[threadgroup(0)]],
|
|
2163
|
-
uint3
|
|
2164
|
-
|
|
2165
|
-
ushort
|
|
2166
|
-
ushort
|
|
2167
|
-
ushort
|
|
2168
|
-
uint3
|
|
2454
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2455
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
2456
|
+
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
|
2457
|
+
ushort tiisg[[thread_index_in_simdgroup]],
|
|
2458
|
+
ushort sgptg[[simdgroups_per_threadgroup]],
|
|
2459
|
+
uint3 tgpg[[threadgroups_per_grid]]) {
|
|
2460
|
+
constexpr short NW = N_SIMDWIDTH;
|
|
2169
2461
|
|
|
2170
|
-
|
|
2171
|
-
|
|
2172
|
-
|
|
2173
|
-
|
|
2462
|
+
// Shared memory layout:
|
|
2463
|
+
// [0..sgptg*NW-1]: partial sums for reduction (existing)
|
|
2464
|
+
// [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch
|
|
2465
|
+
// [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch
|
|
2466
|
+
threadgroup float * shared_sums = shared;
|
|
2467
|
+
threadgroup float * shared_x_dt = shared + sgptg * NW;
|
|
2468
|
+
threadgroup float * shared_dA = shared + sgptg * NW + sgptg;
|
|
2469
|
+
|
|
2470
|
+
shared_sums[tpitg.x] = 0.0f;
|
|
2174
2471
|
|
|
2175
|
-
const
|
|
2176
|
-
const
|
|
2177
|
-
const
|
|
2472
|
+
const int32_t i0 = tpitg.x;
|
|
2473
|
+
const int32_t i1 = tgpig.x;
|
|
2474
|
+
const int32_t ir = tgpig.y; // current head
|
|
2475
|
+
const int32_t i3 = tgpig.z; // current seq
|
|
2178
2476
|
|
|
2179
|
-
const
|
|
2180
|
-
const
|
|
2181
|
-
const
|
|
2182
|
-
const
|
|
2183
|
-
const
|
|
2477
|
+
const int32_t nc = args.d_state;
|
|
2478
|
+
const int32_t nr = args.d_inner;
|
|
2479
|
+
const int32_t nh = args.n_head;
|
|
2480
|
+
const int32_t ng = args.n_group;
|
|
2481
|
+
const int32_t n_t = args.n_seq_tokens;
|
|
2184
2482
|
|
|
2185
|
-
const
|
|
2483
|
+
const int32_t s_off = args.s_off;
|
|
2186
2484
|
|
|
2187
2485
|
device const int32_t * ids = (device const int32_t *) src6;
|
|
2188
2486
|
|
|
2189
2487
|
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
|
|
2190
2488
|
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
|
2191
|
-
|
|
2192
|
-
const
|
|
2489
|
+
|
|
2490
|
+
const int32_t i = i0 + i1*nc;
|
|
2491
|
+
const int32_t g = ir / (nh / ng); // repeat_interleave
|
|
2492
|
+
|
|
2193
2493
|
float s0 = s0_buff[i];
|
|
2194
|
-
float s =
|
|
2195
|
-
|
|
2196
|
-
device const float * A
|
|
2197
|
-
|
|
2198
|
-
|
|
2199
|
-
|
|
2200
|
-
device const float *
|
|
2201
|
-
device
|
|
2202
|
-
|
|
2203
|
-
|
|
2204
|
-
|
|
2205
|
-
|
|
2206
|
-
|
|
2207
|
-
|
|
2208
|
-
|
|
2209
|
-
|
|
2210
|
-
|
|
2211
|
-
|
|
2212
|
-
|
|
2213
|
-
|
|
2214
|
-
|
|
2215
|
-
|
|
2216
|
-
|
|
2217
|
-
|
|
2218
|
-
|
|
2219
|
-
|
|
2220
|
-
|
|
2221
|
-
// To parallelize this effectively, we first use simd_sum over each SIMD
|
|
2222
|
-
// group to compute the sum of each SIMD group, then place the result in
|
|
2223
|
-
// the SIMD group's indexed bucket in the shared memory. We then sum
|
|
2224
|
-
// over the individual group sums to compute the final sum.
|
|
2225
|
-
|
|
2226
|
-
// Computed for each thread
|
|
2227
|
-
float sumf = state * C[i0];
|
|
2228
|
-
|
|
2229
|
-
// Sum the threads in the simd group => simd sum
|
|
2230
|
-
sumf = simd_sum(sumf);
|
|
2231
|
-
|
|
2232
|
-
// Once per simd group, place the group sum into the shared buffer
|
|
2233
|
-
if (tiisg == 0) {
|
|
2234
|
-
shared[sgitg] = sumf;
|
|
2494
|
+
float s = 0.0f;
|
|
2495
|
+
|
|
2496
|
+
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh}
|
|
2497
|
+
|
|
2498
|
+
const float A0 = A[i0%args.ne30];
|
|
2499
|
+
|
|
2500
|
+
device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns}
|
|
2501
|
+
device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns}
|
|
2502
|
+
device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns}
|
|
2503
|
+
device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns}
|
|
2504
|
+
|
|
2505
|
+
device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns}
|
|
2506
|
+
|
|
2507
|
+
for (int i2 = 0; i2 < n_t; i2 += sgptg) {
|
|
2508
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
2509
|
+
|
|
2510
|
+
// Pre-compute x_dt and dA for this batch of tokens
|
|
2511
|
+
// Only first sgptg threads do the loads and expensive math
|
|
2512
|
+
if (i0 < sgptg && i2 + i0 < n_t) {
|
|
2513
|
+
// ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)
|
|
2514
|
+
device const float * x_t = x + i0 * args.ns12;
|
|
2515
|
+
device const float * dt_t = dt + i0 * args.ns21;
|
|
2516
|
+
|
|
2517
|
+
const float dt0 = dt_t[0];
|
|
2518
|
+
const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
|
|
2519
|
+
shared_x_dt[i0] = x_t[0] * dtsp;
|
|
2520
|
+
shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies
|
|
2235
2521
|
}
|
|
2236
2522
|
|
|
2237
|
-
// Wait for all threads in the threadgroup to reach this point. This
|
|
2238
|
-
// ensures that all elements of the shared buffer are populated with the
|
|
2239
|
-
// sum of the individual simd groups.
|
|
2240
2523
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
2241
2524
|
|
|
2242
|
-
|
|
2243
|
-
|
|
2244
|
-
|
|
2245
|
-
|
|
2246
|
-
|
|
2247
|
-
|
|
2248
|
-
|
|
2249
|
-
|
|
2525
|
+
for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
|
|
2526
|
+
const float x_dt = shared_x_dt[t];
|
|
2527
|
+
const float dA = exp(shared_dA[t] * A0);
|
|
2528
|
+
|
|
2529
|
+
s = (s0 * dA) + (B[i0] * x_dt);
|
|
2530
|
+
|
|
2531
|
+
const float sumf = simd_sum(s * C[i0]);
|
|
2532
|
+
|
|
2250
2533
|
if (tiisg == 0) {
|
|
2251
|
-
|
|
2534
|
+
shared_sums[t*NW + sgitg] = sumf;
|
|
2252
2535
|
}
|
|
2536
|
+
|
|
2537
|
+
// recurse
|
|
2538
|
+
s0 = s;
|
|
2539
|
+
|
|
2540
|
+
B += args.ns42;
|
|
2541
|
+
C += args.ns52;
|
|
2542
|
+
}
|
|
2543
|
+
|
|
2544
|
+
// Advance pointers for next batch
|
|
2545
|
+
x += sgptg * args.ns12;
|
|
2546
|
+
dt += sgptg * args.ns21;
|
|
2547
|
+
|
|
2548
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
2549
|
+
|
|
2550
|
+
const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);
|
|
2551
|
+
|
|
2552
|
+
if (tiisg == 0 && i2 + sgitg < n_t) {
|
|
2553
|
+
y[sgitg*nh*nr] = sumf;
|
|
2253
2554
|
}
|
|
2254
2555
|
|
|
2255
|
-
|
|
2256
|
-
s0 = s;
|
|
2556
|
+
y += sgptg*nh*nr;
|
|
2257
2557
|
}
|
|
2258
2558
|
|
|
2259
|
-
// Assign the final state to the output buffer
|
|
2260
2559
|
s_buff[i] = s;
|
|
2261
2560
|
}
|
|
2262
2561
|
|
|
@@ -3761,6 +4060,8 @@ template [[host_name("kernel_mul_mv_bf16_f32_short")]] kernel mul_mv_t_t_short_
|
|
|
3761
4060
|
template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, bfloat>;
|
|
3762
4061
|
#endif
|
|
3763
4062
|
|
|
4063
|
+
// Function constant specialized when the rope pipeline is built.
// In kernel_rope_multi, when true, rotary sectors are assigned to position streams in
// interleaved order (sector % 3 picks h/w/t within the 3*sect_1 / 3*sect_2 / 3*sect_0
// limits, everything else uses stream 3). When false, contiguous sections are used
// (sect_0, then up to sec_w01, then up to sec_w012, then the rest).
constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];
|
|
4064
|
+
|
|
3764
4065
|
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
|
3765
4066
|
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
|
3766
4067
|
return 1.0f - min(1.0f, max(0.0f, y));
|
|
@@ -3830,7 +4131,7 @@ kernel void kernel_rope_norm(
|
|
|
3830
4131
|
|
|
3831
4132
|
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
|
|
3832
4133
|
|
|
3833
|
-
const float freq_factor = src2
|
|
4134
|
+
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
|
|
3834
4135
|
|
|
3835
4136
|
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
|
3836
4137
|
|
|
@@ -3883,7 +4184,7 @@ kernel void kernel_rope_neox(
|
|
|
3883
4184
|
|
|
3884
4185
|
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
|
|
3885
4186
|
|
|
3886
|
-
const float freq_factor = src2
|
|
4187
|
+
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
|
|
3887
4188
|
|
|
3888
4189
|
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
|
3889
4190
|
|
|
@@ -3941,20 +4242,32 @@ kernel void kernel_rope_multi(
|
|
|
3941
4242
|
const int sector = ic % sect_dims;
|
|
3942
4243
|
|
|
3943
4244
|
float theta_base;
|
|
3944
|
-
if (
|
|
3945
|
-
|
|
3946
|
-
|
|
3947
|
-
|
|
3948
|
-
|
|
3949
|
-
|
|
4245
|
+
if (FC_rope_is_imrope) {
|
|
4246
|
+
if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h
|
|
4247
|
+
theta_base = (float) pos[i2 + args.ne02 * 1];
|
|
4248
|
+
} else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w
|
|
4249
|
+
theta_base = (float) pos[i2 + args.ne02 * 2];
|
|
4250
|
+
} else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t
|
|
4251
|
+
theta_base = (float) pos[i2 + args.ne02 * 0];
|
|
4252
|
+
} else { // e
|
|
4253
|
+
theta_base = (float) pos[i2 + args.ne02 * 3];
|
|
4254
|
+
}
|
|
3950
4255
|
} else {
|
|
3951
|
-
|
|
4256
|
+
if (sector < args.sect_0) {
|
|
4257
|
+
theta_base = (float) pos[i2];
|
|
4258
|
+
} else if (sector < sec_w01) {
|
|
4259
|
+
theta_base = (float) pos[i2 + args.ne02 * 1];
|
|
4260
|
+
} else if (sector < sec_w012) {
|
|
4261
|
+
theta_base = (float) pos[i2 + args.ne02 * 2];
|
|
4262
|
+
} else {
|
|
4263
|
+
theta_base = (float) pos[i2 + args.ne02 * 3];
|
|
4264
|
+
}
|
|
3952
4265
|
}
|
|
3953
4266
|
// end of mrope
|
|
3954
4267
|
|
|
3955
4268
|
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
|
|
3956
4269
|
|
|
3957
|
-
const float freq_factor = src2
|
|
4270
|
+
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
|
|
3958
4271
|
|
|
3959
4272
|
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
|
3960
4273
|
|
|
@@ -4021,7 +4334,7 @@ kernel void kernel_rope_vision(
|
|
|
4021
4334
|
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
|
|
4022
4335
|
// end of mrope
|
|
4023
4336
|
|
|
4024
|
-
const float freq_factor = src2
|
|
4337
|
+
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
|
|
4025
4338
|
|
|
4026
4339
|
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
|
4027
4340
|
|
|
@@ -4178,6 +4491,120 @@ template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
|
|
|
4178
4491
|
//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
|
|
4179
4492
|
//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
|
|
4180
4493
|
|
|
4494
|
+
// Direct 2D convolution:
//     dst[n, oc, oh, ow] = sum over (ic, ky, kx) of
//         src[n, ic, oh*s1 - p1 + ky*d1, ow*s0 - p0 + kx*d0] * weights[oc, ic, ky, kx]
// Taps that land outside the input (i.e. in the zero padding) contribute nothing.
//
// Work distribution: the flattened output space [0, N*OC*OH*OW) is walked with a
// stride equal to the total number of dispatched threads, so any grid shape covers
// every output element exactly once.
//
// All tensors are addressed through the byte strides in args:
//   nb00..nb03 for weights, nb10..nb13 for src, nb0..nb3 for dst.
// TK is the weight element type (float or half); src is read and dst is written as float.
template <typename TK>
kernel void kernel_conv_2d(
        constant ggml_metal_kargs_conv_2d & args,
        device const char * weights,
        device const char * src,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3   tgpg[[threadgroups_per_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {

    // Global linear thread id and total thread count, both in 64 bits.
    // A 32-bit thread_index would wrap once the grid exceeds 2^32 threads while
    // total_threads (64-bit) would not; threads would then revisit some outputs
    // and skip others.
    const uint64_t threads_per_tg = (uint64_t) ntg.x * ntg.y * ntg.z;
    const uint64_t tg_index       = ((uint64_t) tgpig.z * tgpg.y + tgpig.y) * tgpg.x + tgpig.x;
    const uint64_t local_thread   = (uint64_t) tpitg.z * (ntg.x * ntg.y) + tpitg.y * ntg.x + tpitg.x;
    const uint64_t thread_index   = tg_index * threads_per_tg + local_thread;
    const uint64_t total_threads  = threads_per_tg * tgpg.x * tgpg.y * tgpg.z;
    const uint64_t total_outputs  = (uint64_t) args.N * args.OC * args.OH * args.OW;

    for (uint64_t index = thread_index; index < total_outputs; index += total_threads) {
        // unflatten index -> (n, oc, oh, ow), with ow varying fastest
        uint64_t tmp = index;

        const int32_t ow = tmp % args.OW; tmp /= args.OW;
        const int32_t oh = tmp % args.OH; tmp /= args.OH;
        const int32_t oc = tmp % args.OC; tmp /= args.OC;
        const int32_t n  = tmp;

        float acc = 0.0f;

        // input coordinate touched by kernel tap (0, 0)
        const int32_t base_x = ow*args.s0 - args.p0;
        const int32_t base_y = oh*args.s1 - args.p1;

        // clamp ky to [ky_start, ky_end) so that 0 <= base_y + ky*d1 < IH
        int32_t ky_start = 0;
        if (base_y < 0) {
            ky_start = (-base_y + args.d1 - 1)/args.d1;
        }
        int32_t ky_end = args.KH;
        const int32_t y_max = args.IH - 1 - base_y;
        if (y_max < 0) {
            ky_end = ky_start;
        } else if (base_y + (args.KH - 1)*args.d1 >= args.IH) {
            ky_end = min(ky_end, y_max/args.d1 + 1);
        }

        // clamp kx to [kx_start, kx_end) so that 0 <= base_x + kx*d0 < IW
        int32_t kx_start = 0;
        if (base_x < 0) {
            kx_start = (-base_x + args.d0 - 1)/args.d0;
        }
        int32_t kx_end = args.KW;
        const int32_t x_max = args.IW - 1 - base_x;
        if (x_max < 0) {
            kx_end = kx_start;
        } else if (base_x + (args.KW - 1)*args.d0 >= args.IW) {
            kx_end = min(kx_end, x_max/args.d0 + 1);
        }

        if (ky_start < ky_end && kx_start < kx_end) {
            const uint64_t src_base_n = (uint64_t) n  * args.nb13;
            const uint64_t w_base_oc  = (uint64_t) oc * args.nb03;

            for (int32_t ic = 0; ic < args.IC; ++ic) {
                const uint64_t src_base_nc = src_base_n + (uint64_t) ic * args.nb12;
                const uint64_t w_base_ocic = w_base_oc  + (uint64_t) ic * args.nb02;

                for (int32_t ky = ky_start; ky < ky_end; ++ky) {
                    const int32_t iy = base_y + ky*args.d1;
                    const uint64_t src_base_row = src_base_nc + (uint64_t) iy * args.nb11;
                    const uint64_t w_base_row   = w_base_ocic + (uint64_t) ky * args.nb01;

                    for (int32_t kx = kx_start; kx < kx_end; ++kx) {
                        const int32_t ix = base_x + kx*args.d0;
                        const uint64_t src_offs = src_base_row + (uint64_t) ix * args.nb10;
                        const uint64_t w_offs   = w_base_row   + (uint64_t) kx * args.nb00;

                        const float x = *(device const float *)(src + src_offs);
                        const float w = (float) (*(device const TK *)(weights + w_offs));

                        acc += x * w;
                    }
                }
            }
        }

        const uint64_t dst_offs =
            (uint64_t) n  * args.nb3 +
            (uint64_t) oc * args.nb2 +
            (uint64_t) oh * args.nb1 +
            (uint64_t) ow * args.nb0;

        *(device float *)(dst + dst_offs) = acc;
    }
}

template [[host_name("kernel_conv_2d_f32_f32")]]
kernel void kernel_conv_2d<float>(
        constant ggml_metal_kargs_conv_2d & args,
        device const char * weights,
        device const char * src,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3   tgpg[[threadgroups_per_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]);

template [[host_name("kernel_conv_2d_f16_f32")]]
kernel void kernel_conv_2d<half>(
        constant ggml_metal_kargs_conv_2d & args,
        device const char * weights,
        device const char * src,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3   tgpg[[threadgroups_per_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]);
|
|
4607
|
+
|
|
4181
4608
|
typedef void (conv_transpose_1d_t)(
|
|
4182
4609
|
constant ggml_metal_kargs_conv_transpose_1d & args,
|
|
4183
4610
|
device const float * src0,
|
|
@@ -4231,6 +4658,97 @@ kernel void kernel_conv_transpose_1d<half>(
|
|
|
4231
4658
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4232
4659
|
uint3 tgpg[[threadgroups_per_grid]]);
|
|
4233
4660
|
|
|
4661
|
+
|
|
4662
|
+
// NOTE(review): this function-type alias does not match kernel_conv_transpose_2d below.
// It has no threadgroup buffer, takes tgpg instead of tpitg/ntg, and src0 is always float.
// The explicit instantiations spell out their own signatures and do not use it.
// Kept unchanged in case other code refers to the name.
typedef void (conv_transpose_2d_t)(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const float * src0,
        device const float * src1,
        device        char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]]);

// Transposed 2D convolution using the same stride args.s0 along both spatial axes.
//
// Threadgroup (out_x, out_y, out_c) = tgpig computes one output element.
// Thread (kw, kh) = (tpitg.x, tpitg.y) handles one kernel tap. Input pixel (in_x, in_y)
// feeds output (in_x*s0 + kw, in_y*s0 + kh), so this tap contributes
//     sum over in_c of src0[in_c, out_c, kh, kw] * src1[in_c, in_y, in_x]
// whenever (out - k) is a non-negative multiple of s0 whose quotient lies inside the input.
//
// Thread 0 combines the per-thread partial sums through shared_sum, which must hold
// ntg.x*ntg.y floats. NOTE(review): tid ignores tpitg.z, so ntg.z is presumably 1.
//
// src0 (kernel weights, element type T) and src1 (input, f32) are indexed as dense
// contiguous arrays; only dst is addressed through byte strides.
template <typename T>
kernel void kernel_conv_transpose_2d(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const T * src0,
        device const float * src1,
        device        char * dst,
        threadgroup float * shared_sum [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]) {

    const int64_t out_x = tgpig[0];
    const int64_t out_y = tgpig[1];
    const int64_t out_c = tgpig[2];

    const int64_t kw = tpitg[0];
    const int64_t kh = tpitg[1];

    float v = 0.0f;

    // The input pixel reached through tap (kh, kw) does not depend on the input channel.
    // The stride/bounds test is therefore done once here, rather than repeated (and
    // rejected the same way) on every iteration of the channel loop.
    const int64_t dy = out_y - kh;
    const int64_t dx = out_x - kw;

    const bool tap_hits_input =
        dy >= 0 && dy % args.s0 == 0 && dy / args.s0 < args.IH &&
        dx >= 0 && dx % args.s0 == 0 && dx / args.s0 < args.IW;

    if (tap_hits_input) {
        const int64_t in_y = dy / args.s0;
        const int64_t in_x = dx / args.s0;

        for (int64_t in_c = 0; in_c < args.IC; in_c++) {
            const int64_t input_idx  = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
            const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;

            v += (float)src0[kernel_idx] * src1[input_idx];
        }
    }

    const uint tid = tpitg.y * ntg.x + tpitg.x;
    shared_sum[tid] = v;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // serial reduction of the per-tap partial sums by the first thread
    if (tid == 0) {
        float total = 0.0f;
        const uint num_threads = ntg.x * ntg.y;
        for (uint i = 0; i < num_threads; i++) {
            total += shared_sum[i];
        }

        device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
        dst_ptr[0] = total;
    }
}

template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
kernel void kernel_conv_transpose_2d<float>(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const float * src0,
        device const float * src1,
        device        char * dst,
        threadgroup float * shared_sum [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]);

template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
kernel void kernel_conv_transpose_2d<half>(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const half  * src0,
        device const float * src1,
        device        char * dst,
        threadgroup float * shared_sum [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]);
|
|
4751
|
+
|
|
4234
4752
|
kernel void kernel_upscale_f32(
|
|
4235
4753
|
constant ggml_metal_kargs_upscale & args,
|
|
4236
4754
|
device const char * src0,
|
|
@@ -4368,69 +4886,234 @@ kernel void kernel_timestep_embedding_f32(
|
|
|
4368
4886
|
// bitonic sort implementation following the CUDA kernels as reference
typedef void (argsort_t)(
        constant ggml_metal_kargs_argsort & args,
        device const char * src0,
        device      int32_t * dst,
        threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]);

// Sort one block of ntg.x consecutive column indices of one f32 row by value.
//
// Grid mapping:
//   tgpig[0] packs the block index ib (= tgpig[0] / ne01) and the row i01 (= tgpig[0] % ne01);
//   tgpig[1], tgpig[2] select i02, i03.
// Each thread owns one slot col of shmem_i32, so the threadgroup buffer must hold
// ntg.x int32 values. NOTE(review): the bitonic network (k doubling up to ntg.x,
// partner col ^ j) presumably requires ntg.x to be a power of two; confirm on the host side.
//
// Slots are seeded with global column indices i00 + col. Indices >= ne00 are padding:
// the comparisons always move them toward the end of the final ordering.
// At most top_k sorted indices of the block are written to dst, starting at column
// ib*top_k of a dst row that is ne0 elements long.
template<ggml_sort_order order>
kernel void kernel_argsort_f32_i32(
        constant ggml_metal_kargs_argsort & args,
        device const char * src0,
        device      int32_t * dst,
        threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    // bitonic sort
    const int col = tpitg[0];
    const int ib  = tgpig[0] / args.ne01;

    const int i00 = ib*ntg.x; // first column covered by this block
    const int i01 = tgpig[0] % args.ne01;
    const int i02 = tgpig[1];
    const int i03 = tgpig[2];

    device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);

    // initialize indices
    shmem_i32[col] = i00 + col;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (int k = 2; k <= ntg.x; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            int ixj = col ^ j;
            // only the lower index of each compare-exchange pair performs the swap
            if (ixj > col) {
                if ((col & k) == 0) {
                    // sort this pair in `order`: padding (>= ne00) is pushed to the higher slot
                    if (shmem_i32[col] >= args.ne00 ||
                       (shmem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
                            src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
                            src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
                    ) {
                        SWAP(shmem_i32[col], shmem_i32[ixj]);
                    }
                } else {
                    // sort this pair in reverse order: padding is pushed to the lower slot
                    if (shmem_i32[ixj] >= args.ne00 ||
                       (shmem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
                            src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
                            src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
                    ) {
                        SWAP(shmem_i32[col], shmem_i32[ixj]);
                    }
                }
            }

            // every compare-exchange stage must complete before the next one reads shmem_i32
            threadgroup_barrier(mem_flags::mem_threadgroup);
        }
    }

    // offset of this block's output inside the dst row
    const int64_t i0 = ib*args.top_k;

    // copy the result to dst without the padding
    if (i0 + col < args.ne0 && col < args.top_k) {
        dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;

        dst[col] = shmem_i32[col];
    }
}

template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
|
|
4433
4960
|
|
|
4961
|
+
typedef void (argsort_merge_t)(
        constant ggml_metal_kargs_argsort_merge & args,
        device const char    * src0,
        device const int32_t * tmp,
        device       int32_t * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]);

// Ordering predicate shared by the partition search and the merge loop of
// kernel_argsort_merge_f32_i32. Returns true when the left-run element (key a) must be
// emitted before the right-run element (key b). Ties go to the left run.
template<ggml_sort_order order>
static inline bool argsort_merge_take_left(float a, float b) {
    return order == GGML_SORT_ORDER_ASC ? (a <= b) : (a >= b);
}

// Merge two adjacent sorted runs of column indices taken from tmp into dst, using the
// f32 values of src0 as sort keys.
//
// Grid mapping: tgpig[0] packs the run pair (tgpig[0] / ne01) and the row i01
// (tgpig[0] % ne01); tgpig[1], tgpig[2] select i02, i03.
// The pair covers columns [start, start + 2*len): the left run comes first, then the
// right run, each clipped to the ne0 columns that exist. tmp rows are ne0 entries long;
// dst rows are top_k entries long, and only the first top_k merged outputs are produced.
//
// Each thread emits one contiguous slice [k0, k1) of the merged sequence. It finds
// where that slice starts in each run with a merge-path binary search, then merges
// sequentially, keeping the two current run heads in registers.
template<ggml_sort_order order>
kernel void kernel_argsort_merge_f32_i32(
        constant ggml_metal_kargs_argsort_merge & args,
        device const char    * src0,
        device const int32_t * tmp,
        device       int32_t * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {

    const int pair = tgpig[0] / args.ne01;
    const int i01  = tgpig[0] % args.ne01;
    const int i02  = tgpig[1];
    const int i03  = tgpig[2];

    const int start = pair * (2 * args.len);
    const int len0  = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
    const int len1  = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
    const int total = len0 + len1;

    // slice of the merged output owned by this thread
    const int chunk = (total + ntg.x - 1) / ntg.x;
    const int k0    = tpitg.x * chunk;
    const int k1    = MIN(MIN(k0 + chunk, total), args.top_k);

    // nothing to merge, or this slice starts past the data / past top_k
    if (total == 0 || k0 >= MIN(total, args.top_k)) {
        return;
    }

    device const int32_t * run0 = tmp + start
        + i01*args.ne0
        + i02*args.ne0*args.ne01
        + i03*args.ne0*args.ne01*args.ne02;
    device const int32_t * run1 = run0 + args.len;

    device int32_t * out = dst + (start
        + i01*args.top_k
        + i02*args.top_k*args.ne01
        + i03*args.top_k*args.ne01*args.ne02);

    device const float * src0_row = (device const float *)(src0
        + args.nb01*i01
        + args.nb02*i02
        + args.nb03*i03);

    // merge-path partition: lo ends up as the number of the first k0 merged outputs
    // that come from the left run (the remaining k0 - lo come from the right run)
    int lo = MAX(0, k0 - len1);
    int hi = MIN(k0, len0);
    while (lo < hi) {
        const int mid = (lo + hi) >> 1;
        if (argsort_merge_take_left<order>(src0_row[run0[mid]], src0_row[run1[k0 - mid - 1]])) {
            lo = mid + 1;
        } else {
            hi = mid;
        }
    }

    int a = lo;      // next position in the left run
    int b = k0 - lo; // next position in the right run

    // current heads of both runs, loaded only while the run is non-empty
    int32_t head0 = 0;
    float   key0  = 0.0f;
    if (a < len0) {
        head0 = run0[a];
        key0  = src0_row[head0];
    }

    int32_t head1 = 0;
    float   key1  = 0.0f;
    if (b < len1) {
        head1 = run1[b];
        key1  = src0_row[head1];
    }

    int k = k0;

    // both runs still have elements: emit whichever head comes first
    while (k < k1 && a < len0 && b < len1) {
        if (argsort_merge_take_left<order>(key0, key1)) {
            out[k++] = head0;
            if (++a < len0) {
                head0 = run0[a];
                key0  = src0_row[head0];
            }
        } else {
            out[k++] = head1;
            if (++b < len1) {
                head1 = run1[b];
                key1  = src0_row[head1];
            }
        }
    }

    // at most one run is left: copy its remainder without comparisons
    for (; k < k1 && a < len0; ++k) {
        out[k] = run0[a++];
    }
    for (; k < k1 && b < len1; ++k) {
        out[k] = run1[b++];
    }
}

template [[host_name("kernel_argsort_merge_f32_i32_asc")]]  kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
|
|
5116
|
+
|
|
4434
5117
|
kernel void kernel_leaky_relu_f32(
|
|
4435
5118
|
constant ggml_metal_kargs_leaky_relu & args,
|
|
4436
5119
|
device const float * src0,
|
|
@@ -4449,10 +5132,142 @@ kernel void kernel_leaky_relu_f32_4(
|
|
|
4449
5132
|
dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
|
|
4450
5133
|
}
|
|
4451
5134
|
|
|
5135
|
+
constant bool    FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];

constant int32_t FC_flash_attn_ext_pad_ncpsg    [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];

// Threadgroup-strided byte copy used by kernel_flash_attn_ext_pad.
// Writes nbytes bytes of dst: zeros when zero_fill is set (src is then never read),
// otherwise the corresponding bytes of src. Thread tiitg handles bytes tiitg,
// tiitg + nthreads, ...
static inline void flash_attn_ext_pad_row(
        device       char * dst,
        device const char * src,
        uint64_t nbytes,
        bool     zero_fill,
        ushort   tiitg,
        ushort   nthreads) {
    if (zero_fill) {
        for (uint64_t i = tiitg; i < nbytes; i += nthreads) {
            dst[i] = 0;
        }
        return;
    }

    for (uint64_t i = tiitg; i < nbytes; i += nthreads) {
        dst[i] = src[i];
    }
}

// Copy the last, partially filled chunk of C = FC_flash_attn_ext_pad_ncpsg rows of K and V
// (and, when present, the matching mask columns) into a separate pad buffer.
// Rows past the real data are zero-filled; the corresponding mask entries are set to -MAXHALF.
//
// dst layout: [K pad | V pad | mask pad]. The K and V pads hold C rows for every (i2, i3).
// tgpig = (row i1 inside the chunk, i2, i3).
kernel void kernel_flash_attn_ext_pad(
        constant ggml_metal_kargs_flash_attn_ext_pad & args,
        device const char * k,
        device const char * v,
        device const char * mask,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int32_t C = FC_flash_attn_ext_pad_ncpsg;

    // the trailing chunk starts at row ic0 and contains icp real rows
    const int32_t icp = args.ne11 % C;
    const int32_t ic0 = args.ne11 - icp;

    const int32_t i1 = tgpig[0];
    const int32_t i2 = tgpig[1];
    const int32_t i3 = tgpig[2];

    device char * k_pad    = dst;
    device char * v_pad    = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3;
    device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3;

    if (i2 < args.ne_12_2 && i3 < args.ne_12_3) {
        // the exact value of padded rows does not matter: their attention scores are masked out
        const bool is_pad_row = i1 >= icp;

        flash_attn_ext_pad_row(
                k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3,
                k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3,
                args.nb11, is_pad_row, tiitg, ntg.x);

        flash_attn_ext_pad_row(
                v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3,
                v + args.nb21*(ic0 + i1) + args.nb22*i2 + args.nb23*i3,
                args.nb21, is_pad_row, tiitg, ntg.x);
    }

    if (!FC_flash_attn_ext_pad_has_mask || i2 >= args.ne32 || i3 >= args.ne33) {
        return;
    }

    // mask rows i1, i1 + C, ...: copy the last C columns, with -MAXHALF in the padded ones
    for (int ib = i1; ib < args.ne31; ib += C) {
        device const half * mask_src = (device const half *)(mask + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0;
        device       half * mask_dst = (device half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3;

        for (int i = tiitg; i < C; i += ntg.x) {
            if (i < icp) {
                mask_dst[i] = mask_src[i];
            } else {
                mask_dst[i] = -MAXHALF;
            }
        }
    }
}
|
|
5204
|
+
|
|
5205
|
+
constant int32_t FC_flash_attn_ext_blk_nqptg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 24)]];
|
|
5206
|
+
constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 25)]];
|
|
5207
|
+
|
|
5208
|
+
// scan the blocks of the mask that are not masked
|
|
5209
|
+
// 0 - masked (i.e. full of -INF, skip)
|
|
5210
|
+
// 1 - not masked (i.e. at least one element of the mask is not -INF)
|
|
5211
|
+
kernel void kernel_flash_attn_ext_blk(
|
|
5212
|
+
constant ggml_metal_kargs_flash_attn_ext_blk & args,
|
|
5213
|
+
device const char * mask,
|
|
5214
|
+
device char * dst,
|
|
5215
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5216
|
+
ushort tiisg[[thread_index_in_simdgroup]]) {
|
|
5217
|
+
// block size C x Q
|
|
5218
|
+
const int32_t Q = FC_flash_attn_ext_blk_nqptg;
|
|
5219
|
+
const int32_t C = FC_flash_attn_ext_blk_ncpsg;
|
|
5220
|
+
|
|
5221
|
+
constexpr short NW = N_SIMDWIDTH;
|
|
5222
|
+
|
|
5223
|
+
const int32_t i3 = tgpig[2]/args.ne32;
|
|
5224
|
+
const int32_t i2 = tgpig[2]%args.ne32;
|
|
5225
|
+
const int32_t i1 = tgpig[1];
|
|
5226
|
+
const int32_t i0 = tgpig[0];
|
|
5227
|
+
|
|
5228
|
+
char res = i0*C + C > args.ne30 ? 1 : 0;
|
|
5229
|
+
|
|
5230
|
+
device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;
|
|
5231
|
+
|
|
5232
|
+
// fast route
|
|
5233
|
+
if (res == 0) {
|
|
5234
|
+
if (simd_max(*mask_src) > -MAXHALF/2) {
|
|
5235
|
+
res = 1;
|
|
5236
|
+
}
|
|
5237
|
+
}
|
|
5238
|
+
|
|
5239
|
+
// detailed check of the elements of the block
|
|
5240
|
+
if ((C > NW || Q > 1) && res == 0) {
|
|
5241
|
+
half m = -MAXHALF;
|
|
5242
|
+
|
|
5243
|
+
FOR_UNROLL (short j = 0; j < Q; ++j) {
|
|
5244
|
+
FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
|
|
5245
|
+
m = max(m, mask_src[ii*NW]);
|
|
5246
|
+
}
|
|
5247
|
+
|
|
5248
|
+
mask_src += args.nb31/2;
|
|
5249
|
+
}
|
|
5250
|
+
|
|
5251
|
+
if (simd_max(m) > -MAXHALF/2) {
|
|
5252
|
+
res = 1;
|
|
5253
|
+
}
|
|
5254
|
+
}
|
|
5255
|
+
|
|
5256
|
+
const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
|
|
5257
|
+
const int32_t nblk0 = ((args.ne30 + C - 1)/C);
|
|
5258
|
+
|
|
5259
|
+
if (tiisg == 0) {
|
|
5260
|
+
dst[((i3*args.ne32 + i2)*nblk1 + i1)*nblk0 + i0] = res;
|
|
5261
|
+
}
|
|
5262
|
+
}
|
|
5263
|
+
|
|
4452
5264
|
constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]];
|
|
4453
5265
|
constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]];
|
|
4454
5266
|
constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]];
|
|
4455
5267
|
constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]];
|
|
5268
|
+
constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]];
|
|
5269
|
+
|
|
5270
|
+
constant bool FC_flash_attn_ext_bc_mask [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
|
|
4456
5271
|
|
|
4457
5272
|
//constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
|
|
4458
5273
|
//constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]];
|
|
@@ -4499,6 +5314,8 @@ void kernel_flash_attn_ext_impl(
|
|
|
4499
5314
|
device const char * v,
|
|
4500
5315
|
device const char * mask,
|
|
4501
5316
|
device const char * sinks,
|
|
5317
|
+
device const char * pad,
|
|
5318
|
+
device const char * blk,
|
|
4502
5319
|
device char * dst,
|
|
4503
5320
|
threadgroup half * shmem_f16,
|
|
4504
5321
|
uint3 tgpig,
|
|
@@ -4564,6 +5381,13 @@ void kernel_flash_attn_ext_impl(
|
|
|
4564
5381
|
pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
|
|
4565
5382
|
}
|
|
4566
5383
|
|
|
5384
|
+
{
|
|
5385
|
+
const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
|
|
5386
|
+
const int32_t nblk0 = ((args.ne11 + C - 1)/C);
|
|
5387
|
+
|
|
5388
|
+
blk += (((iq3%args.ne33)*args.ne32 + (iq2%args.ne32))*nblk1 + iq1/Q)*nblk0;
|
|
5389
|
+
}
|
|
5390
|
+
|
|
4567
5391
|
{
|
|
4568
5392
|
q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;
|
|
4569
5393
|
|
|
@@ -4623,16 +5447,75 @@ void kernel_flash_attn_ext_impl(
|
|
|
4623
5447
|
|
|
4624
5448
|
// loop over the KV cache
|
|
4625
5449
|
// each simdgroup handles blocks of Q rows and C columns
|
|
4626
|
-
for (int
|
|
4627
|
-
|
|
4628
|
-
if (
|
|
4629
|
-
|
|
4630
|
-
|
|
4631
|
-
|
|
4632
|
-
|
|
5450
|
+
for (int ic0 = 0; ; ++ic0) {
|
|
5451
|
+
int ic = ic0*C;
|
|
5452
|
+
if (ic >= args.ne11) {
|
|
5453
|
+
break;
|
|
5454
|
+
}
|
|
5455
|
+
|
|
5456
|
+
// the last partial chunk uses the pad buffer as source
|
|
5457
|
+
if (FC_flash_attn_ext_has_kvpad && ic + C > args.ne11) {
|
|
5458
|
+
k = pad;
|
|
5459
|
+
v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
|
|
5460
|
+
mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
|
|
5461
|
+
|
|
5462
|
+
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
|
|
5463
|
+
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
|
|
5464
|
+
|
|
5465
|
+
k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
|
|
5466
|
+
v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
|
|
5467
|
+
|
|
5468
|
+
if (!FC_flash_attn_ext_has_mask) {
|
|
5469
|
+
threadgroup half * sm = (threadgroup half *) (sm2);
|
|
5470
|
+
|
|
5471
|
+
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
|
|
5472
|
+
const short j = jj*NSG + sgitg;
|
|
5473
|
+
|
|
5474
|
+
for (short i = tiisg; i < C; i += NW) {
|
|
5475
|
+
if (ic + i >= args.ne11) {
|
|
5476
|
+
sm[2*j*SH + i] = -MAXHALF;
|
|
5477
|
+
}
|
|
5478
|
+
}
|
|
5479
|
+
}
|
|
5480
|
+
} else {
|
|
5481
|
+
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
|
|
5482
|
+
const short j = jj*NSG + sgitg;
|
|
5483
|
+
|
|
5484
|
+
pm2[jj] = (device const half2 *) ((device const half *) mask +
|
|
5485
|
+
(iq1 + j)*C +
|
|
5486
|
+
(iq2%args.ne32)*(C*args.ne31) +
|
|
5487
|
+
(iq3%args.ne33)*(C*args.ne31*args.ne32));
|
|
5488
|
+
}
|
|
5489
|
+
}
|
|
5490
|
+
|
|
5491
|
+
ic = 0;
|
|
5492
|
+
}
|
|
5493
|
+
|
|
5494
|
+
// read the mask into shared mem
|
|
5495
|
+
if (FC_flash_attn_ext_has_mask) {
|
|
5496
|
+
if (blk[ic0] == 0) {
|
|
5497
|
+
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
|
|
5498
|
+
pm2[jj] += NW;
|
|
5499
|
+
}
|
|
5500
|
+
|
|
5501
|
+
continue;
|
|
5502
|
+
}
|
|
5503
|
+
|
|
5504
|
+
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
|
|
5505
|
+
const short j = jj*NSG + sgitg;
|
|
5506
|
+
|
|
5507
|
+
if (FC_flash_attn_ext_bc_mask) {
|
|
5508
|
+
sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
|
|
5509
|
+
} else {
|
|
5510
|
+
sm2[j*SH + tiisg] = pm2[jj][tiisg];
|
|
5511
|
+
}
|
|
5512
|
+
|
|
4633
5513
|
pm2[jj] += NW;
|
|
4634
5514
|
}
|
|
4635
5515
|
|
|
5516
|
+
#if 0
|
|
5517
|
+
// note: old -INF block optimization - obsoleted by pre-computing non-masked blocks
|
|
5518
|
+
|
|
4636
5519
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
4637
5520
|
|
|
4638
5521
|
// used to detect blocks full of -INF
|
|
@@ -4651,13 +5534,14 @@ void kernel_flash_attn_ext_impl(
|
|
|
4651
5534
|
|
|
4652
5535
|
continue;
|
|
4653
5536
|
}
|
|
5537
|
+
#endif
|
|
4654
5538
|
}
|
|
4655
5539
|
|
|
4656
5540
|
// Q*K^T
|
|
4657
5541
|
// this is compile-time check, so it does not have runtime overhead
|
|
4658
5542
|
if (is_same<kd4x4_t, k4x4_t>::value) {
|
|
4659
5543
|
// we can read directly from global memory
|
|
4660
|
-
device const k_t * pk = (device const k_t *) (
|
|
5544
|
+
device const k_t * pk = (device const k_t *) (k + ic*args.nb11);
|
|
4661
5545
|
threadgroup const q_t * pq = sq;
|
|
4662
5546
|
threadgroup s_t * ps = ss;
|
|
4663
5547
|
|
|
@@ -4668,26 +5552,24 @@ void kernel_flash_attn_ext_impl(
|
|
|
4668
5552
|
|
|
4669
5553
|
constexpr short NC = (C/8)/NSG;
|
|
4670
5554
|
|
|
4671
|
-
//
|
|
5555
|
+
// note: do not unroll for large heads
|
|
5556
|
+
#pragma unroll (DK <= 64 ? NC : 1)
|
|
4672
5557
|
for (short cc = 0; cc < NC; ++cc) {
|
|
4673
5558
|
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
|
|
4674
5559
|
|
|
4675
|
-
if (
|
|
5560
|
+
if (DK % 16 != 0) {
|
|
4676
5561
|
k8x8_t mk;
|
|
4677
5562
|
q8x8_t mq;
|
|
4678
5563
|
|
|
4679
5564
|
FOR_UNROLL (short i = 0; i < DK8; ++i) {
|
|
4680
5565
|
simdgroup_barrier(mem_flags::mem_none);
|
|
4681
5566
|
|
|
4682
|
-
simdgroup_load(mk, pk, NS10, 0, true);
|
|
4683
|
-
simdgroup_load(mq, pq, DK);
|
|
5567
|
+
simdgroup_load(mk, pk + 8*i, NS10, 0, true);
|
|
5568
|
+
simdgroup_load(mq, pq + 8*i, DK);
|
|
4684
5569
|
|
|
4685
5570
|
simdgroup_barrier(mem_flags::mem_none);
|
|
4686
5571
|
|
|
4687
5572
|
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
|
|
4688
|
-
|
|
4689
|
-
pk += 8;
|
|
4690
|
-
pq += 8;
|
|
4691
5573
|
}
|
|
4692
5574
|
} else {
|
|
4693
5575
|
k8x8_t mk[2];
|
|
@@ -4696,26 +5578,22 @@ void kernel_flash_attn_ext_impl(
|
|
|
4696
5578
|
FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
|
|
4697
5579
|
simdgroup_barrier(mem_flags::mem_none);
|
|
4698
5580
|
|
|
4699
|
-
simdgroup_load(
|
|
4700
|
-
simdgroup_load(
|
|
5581
|
+
simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
|
|
5582
|
+
simdgroup_load(mq[1], pq + 1*8 + 16*i, DK);
|
|
4701
5583
|
|
|
4702
|
-
simdgroup_load(
|
|
4703
|
-
simdgroup_load(
|
|
5584
|
+
simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true);
|
|
5585
|
+
simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true);
|
|
4704
5586
|
|
|
4705
5587
|
simdgroup_barrier(mem_flags::mem_none);
|
|
4706
5588
|
|
|
4707
5589
|
simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk);
|
|
4708
5590
|
simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk);
|
|
4709
|
-
|
|
4710
|
-
pk += 16;
|
|
4711
|
-
pq += 16;
|
|
4712
5591
|
}
|
|
4713
5592
|
}
|
|
4714
5593
|
|
|
4715
5594
|
simdgroup_store(mqk, ps, SH, 0, false);
|
|
4716
5595
|
|
|
4717
|
-
pk += 8*(NSG*NS10
|
|
4718
|
-
pq += 8*(NSG*0 - DK8);
|
|
5596
|
+
pk += 8*(NSG*NS10);
|
|
4719
5597
|
ps += 8*(NSG);
|
|
4720
5598
|
}
|
|
4721
5599
|
} else {
|
|
@@ -4729,7 +5607,7 @@ void kernel_flash_attn_ext_impl(
|
|
|
4729
5607
|
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
|
|
4730
5608
|
|
|
4731
5609
|
for (short ii = 0; ii < DK16; ii += 4) {
|
|
4732
|
-
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (
|
|
5610
|
+
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11));
|
|
4733
5611
|
|
|
4734
5612
|
if (DK16%4 == 0) {
|
|
4735
5613
|
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
|
|
@@ -4849,27 +5727,50 @@ void kernel_flash_attn_ext_impl(
|
|
|
4849
5727
|
}
|
|
4850
5728
|
|
|
4851
5729
|
{
|
|
4852
|
-
|
|
4853
|
-
|
|
4854
|
-
device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21);
|
|
5730
|
+
device const v_t * pv = (device const v_t *) (v + ic*args.nb21);
|
|
4855
5731
|
|
|
4856
5732
|
pv += 8*sgitg;
|
|
4857
5733
|
|
|
4858
|
-
|
|
4859
|
-
|
|
4860
|
-
|
|
5734
|
+
if (DV <= 64) {
|
|
5735
|
+
FOR_UNROLL (short cc = 0; cc < C/8; ++cc) {
|
|
5736
|
+
s8x8_t vs;
|
|
5737
|
+
simdgroup_load(vs, ss + 8*cc, SH, 0, false);
|
|
4861
5738
|
|
|
4862
|
-
|
|
4863
|
-
|
|
5739
|
+
FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
|
|
5740
|
+
v8x8_t mv[2];
|
|
4864
5741
|
|
|
4865
|
-
|
|
4866
|
-
|
|
5742
|
+
simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false);
|
|
5743
|
+
simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false);
|
|
4867
5744
|
|
|
4868
|
-
|
|
5745
|
+
simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]);
|
|
5746
|
+
simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]);
|
|
5747
|
+
}
|
|
5748
|
+
|
|
5749
|
+
pv += 8*NS20;
|
|
4869
5750
|
}
|
|
5751
|
+
} else {
|
|
5752
|
+
FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
|
|
5753
|
+
s8x8_t vs[2];
|
|
5754
|
+
|
|
5755
|
+
simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
|
|
5756
|
+
simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false);
|
|
4870
5757
|
|
|
4871
|
-
|
|
4872
|
-
|
|
5758
|
+
FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
|
|
5759
|
+
v8x8_t mv[4];
|
|
5760
|
+
|
|
5761
|
+
simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
|
|
5762
|
+
simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
|
|
5763
|
+
simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);
|
|
5764
|
+
simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);
|
|
5765
|
+
|
|
5766
|
+
simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]);
|
|
5767
|
+
simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]);
|
|
5768
|
+
simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]);
|
|
5769
|
+
simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]);
|
|
5770
|
+
}
|
|
5771
|
+
|
|
5772
|
+
pv += 2*8*NS20;
|
|
5773
|
+
}
|
|
4873
5774
|
}
|
|
4874
5775
|
}
|
|
4875
5776
|
|
|
@@ -4893,7 +5794,7 @@ void kernel_flash_attn_ext_impl(
|
|
|
4893
5794
|
simdgroup_load(vs, ss + 8*cc, SH, 0, false);
|
|
4894
5795
|
|
|
4895
5796
|
for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) {
|
|
4896
|
-
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (
|
|
5797
|
+
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21));
|
|
4897
5798
|
|
|
4898
5799
|
if (DV16%4 == 0) {
|
|
4899
5800
|
// no need for bound checks
|
|
@@ -4983,7 +5884,7 @@ void kernel_flash_attn_ext_impl(
|
|
|
4983
5884
|
|
|
4984
5885
|
device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
|
|
4985
5886
|
|
|
4986
|
-
const float scale = 1.0f/S[jj];
|
|
5887
|
+
const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj];
|
|
4987
5888
|
|
|
4988
5889
|
if (DV4 % NW == 0) {
|
|
4989
5890
|
FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
|
|
@@ -5028,8 +5929,8 @@ template<
|
|
|
5028
5929
|
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
|
|
5029
5930
|
short DK, // K head size
|
|
5030
5931
|
short DV, // V head size
|
|
5031
|
-
short Q =
|
|
5032
|
-
short C =
|
|
5932
|
+
short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup
|
|
5933
|
+
short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup
|
|
5033
5934
|
kernel void kernel_flash_attn_ext(
|
|
5034
5935
|
constant ggml_metal_kargs_flash_attn_ext & args,
|
|
5035
5936
|
device const char * q,
|
|
@@ -5037,13 +5938,15 @@ kernel void kernel_flash_attn_ext(
|
|
|
5037
5938
|
device const char * v,
|
|
5038
5939
|
device const char * mask,
|
|
5039
5940
|
device const char * sinks,
|
|
5941
|
+
device const char * pad,
|
|
5942
|
+
device const char * blk,
|
|
5040
5943
|
device char * dst,
|
|
5041
5944
|
threadgroup half * shmem_f16 [[threadgroup(0)]],
|
|
5042
5945
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5043
5946
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
5044
5947
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5045
5948
|
#define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C
|
|
5046
|
-
#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
|
|
5949
|
+
#define FWD_ARGS args, q, k, v, mask, sinks, pad, blk, dst, shmem_f16, tgpig, tiisg, sgitg
|
|
5047
5950
|
switch (FC_flash_attn_ext_nsg) {
|
|
5048
5951
|
// note: disabled cases to reduce library load time
|
|
5049
5952
|
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
|
@@ -5075,10 +5978,36 @@ kernel void kernel_flash_attn_ext(
|
|
|
5075
5978
|
half, half4, simdgroup_half8x8
|
|
5076
5979
|
//float, float4, simdgroup_float8x8
|
|
5077
5980
|
|
|
5981
|
+
#define FA_TYPES_F32 \
|
|
5982
|
+
half, half4, simdgroup_half8x8, \
|
|
5983
|
+
float, float4x4, simdgroup_float8x8, \
|
|
5984
|
+
float, float4x4, simdgroup_float8x8, \
|
|
5985
|
+
float, simdgroup_float8x8, \
|
|
5986
|
+
float, float2, simdgroup_float8x8, \
|
|
5987
|
+
float, float4, simdgroup_float8x8
|
|
5988
|
+
//half, half4, simdgroup_half8x8
|
|
5989
|
+
|
|
5078
5990
|
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
|
|
5079
5991
|
|
|
5992
|
+
template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
|
|
5993
|
+
template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
|
|
5994
|
+
template [[host_name("kernel_flash_attn_ext_f32_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 48, 48>;
|
|
5995
|
+
template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
|
|
5996
|
+
template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 72, 72>;
|
|
5997
|
+
template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
|
|
5998
|
+
template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
|
|
5999
|
+
template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
|
|
6000
|
+
template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 128, 128>;
|
|
6001
|
+
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>;
|
|
6002
|
+
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
|
|
6003
|
+
template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
|
|
6004
|
+
template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;
|
|
6005
|
+
|
|
6006
|
+
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
|
|
5080
6007
|
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
|
|
6008
|
+
template [[host_name("kernel_flash_attn_ext_f16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 48, 48>;
|
|
5081
6009
|
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
|
|
6010
|
+
template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 72, 72>;
|
|
5082
6011
|
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
|
|
5083
6012
|
template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
|
|
5084
6013
|
template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
|
|
@@ -5089,8 +6018,11 @@ template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_at
|
|
|
5089
6018
|
template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
|
|
5090
6019
|
|
|
5091
6020
|
#if defined(GGML_METAL_HAS_BF16)
|
|
6021
|
+
template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
|
|
5092
6022
|
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
|
|
6023
|
+
template [[host_name("kernel_flash_attn_ext_bf16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 48, 48>;
|
|
5093
6024
|
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
|
|
6025
|
+
template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 72, 72>;
|
|
5094
6026
|
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
|
|
5095
6027
|
template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
|
|
5096
6028
|
template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
|
|
@@ -5101,8 +6033,11 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_at
|
|
|
5101
6033
|
template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
|
|
5102
6034
|
#endif
|
|
5103
6035
|
|
|
6036
|
+
template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
|
|
5104
6037
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
|
|
6038
|
+
template [[host_name("kernel_flash_attn_ext_q4_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 48, 48>;
|
|
5105
6039
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
|
|
6040
|
+
template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 72, 72>;
|
|
5106
6041
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
|
|
5107
6042
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
|
|
5108
6043
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
|
|
@@ -5112,8 +6047,11 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_at
|
|
|
5112
6047
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
|
|
5113
6048
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
|
|
5114
6049
|
|
|
6050
|
+
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
|
|
5115
6051
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
|
|
6052
|
+
template [[host_name("kernel_flash_attn_ext_q4_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 48, 48>;
|
|
5116
6053
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
|
|
6054
|
+
template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 72, 72>;
|
|
5117
6055
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
|
|
5118
6056
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
|
|
5119
6057
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
|
|
@@ -5123,8 +6061,11 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_at
|
|
|
5123
6061
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
|
|
5124
6062
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
|
|
5125
6063
|
|
|
6064
|
+
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
|
|
5126
6065
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
|
|
6066
|
+
template [[host_name("kernel_flash_attn_ext_q5_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 48, 48>;
|
|
5127
6067
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
|
|
6068
|
+
template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 72, 72>;
|
|
5128
6069
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
|
|
5129
6070
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
|
|
5130
6071
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
|
|
@@ -5134,8 +6075,11 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_at
|
|
|
5134
6075
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
|
|
5135
6076
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
|
|
5136
6077
|
|
|
6078
|
+
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
|
|
5137
6079
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
|
|
6080
|
+
template [[host_name("kernel_flash_attn_ext_q5_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 48, 48>;
|
|
5138
6081
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
|
|
6082
|
+
template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 72, 72>;
|
|
5139
6083
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
|
|
5140
6084
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
|
|
5141
6085
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
|
|
@@ -5145,8 +6089,11 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_at
|
|
|
5145
6089
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
|
|
5146
6090
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
|
|
5147
6091
|
|
|
6092
|
+
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
|
|
5148
6093
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
|
|
6094
|
+
template [[host_name("kernel_flash_attn_ext_q8_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 48, 48>;
|
|
5149
6095
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
|
|
6096
|
+
template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 72, 72>;
|
|
5150
6097
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
|
|
5151
6098
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
|
|
5152
6099
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
|
|
@@ -5158,11 +6105,13 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_at
|
|
|
5158
6105
|
|
|
5159
6106
|
#undef FA_TYPES
|
|
5160
6107
|
#undef FA_TYPES_BF
|
|
6108
|
+
#undef FA_TYPES_F32
|
|
5161
6109
|
|
|
5162
6110
|
constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]];
|
|
5163
6111
|
constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]];
|
|
5164
6112
|
constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]];
|
|
5165
6113
|
constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]];
|
|
6114
|
+
constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]];
|
|
5166
6115
|
|
|
5167
6116
|
//constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]];
|
|
5168
6117
|
//constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]];
|
|
@@ -5189,9 +6138,9 @@ template<
|
|
|
5189
6138
|
void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
|
|
5190
6139
|
short DK, // K head size
|
|
5191
6140
|
short DV, // V head size
|
|
5192
|
-
short NE
|
|
5193
|
-
short Q
|
|
5194
|
-
short C
|
|
6141
|
+
short NE, // head elements per thread
|
|
6142
|
+
short Q, // queries per threadgroup
|
|
6143
|
+
short C, // cache items per threadgroup
|
|
5195
6144
|
short NSG> // number of simd groups
|
|
5196
6145
|
void kernel_flash_attn_ext_vec_impl(
|
|
5197
6146
|
constant ggml_metal_kargs_flash_attn_ext_vec & args,
|
|
@@ -5200,6 +6149,7 @@ void kernel_flash_attn_ext_vec_impl(
|
|
|
5200
6149
|
device const char * v,
|
|
5201
6150
|
device const char * mask,
|
|
5202
6151
|
device const char * sinks,
|
|
6152
|
+
device const char * pad,
|
|
5203
6153
|
device char * dst,
|
|
5204
6154
|
threadgroup half * shmem_f16 [[threadgroup(0)]],
|
|
5205
6155
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
@@ -5305,12 +6255,38 @@ void kernel_flash_attn_ext_vec_impl(
|
|
|
5305
6255
|
|
|
5306
6256
|
// loop over the KV cache
|
|
5307
6257
|
// each simdgroup handles blocks of Q rows and C columns
|
|
5308
|
-
for (int ic0 =
|
|
5309
|
-
|
|
6258
|
+
for (int ic0 = iwg*NSG + sgitg; ; ic0 += NWG*NSG) {
|
|
6259
|
+
int ic = ic0*C;
|
|
5310
6260
|
if (ic >= args.ne11) {
|
|
5311
6261
|
break;
|
|
5312
6262
|
}
|
|
5313
6263
|
|
|
6264
|
+
// the last partial chunk uses the pad buffer as source
|
|
6265
|
+
if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) {
|
|
6266
|
+
k = pad;
|
|
6267
|
+
v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
|
|
6268
|
+
mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
|
|
6269
|
+
|
|
6270
|
+
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
|
|
6271
|
+
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
|
|
6272
|
+
|
|
6273
|
+
k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
|
|
6274
|
+
v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
|
|
6275
|
+
|
|
6276
|
+
if (!FC_flash_attn_ext_vec_has_mask) {
|
|
6277
|
+
if (ic + tiisg >= args.ne11) {
|
|
6278
|
+
sm[tiisg] = -MAXHALF;
|
|
6279
|
+
}
|
|
6280
|
+
} else {
|
|
6281
|
+
pm = (device const half *) (mask) +
|
|
6282
|
+
iq1*C +
|
|
6283
|
+
(iq2%args.ne32)*(C*args.ne31) +
|
|
6284
|
+
(iq3%args.ne33)*(C*args.ne31*args.ne32);
|
|
6285
|
+
}
|
|
6286
|
+
|
|
6287
|
+
ic = 0;
|
|
6288
|
+
}
|
|
6289
|
+
|
|
5314
6290
|
if (FC_flash_attn_ext_vec_has_mask) {
|
|
5315
6291
|
sm[tiisg] = pm[ic + tiisg];
|
|
5316
6292
|
}
|
|
@@ -5322,7 +6298,7 @@ void kernel_flash_attn_ext_vec_impl(
|
|
|
5322
6298
|
|
|
5323
6299
|
// Q*K^T
|
|
5324
6300
|
{
|
|
5325
|
-
device const k4_t * pk4 = (device const k4_t *) (
|
|
6301
|
+
device const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11);
|
|
5326
6302
|
threadgroup const q4_t * pq4 = sq4;
|
|
5327
6303
|
|
|
5328
6304
|
pk4 += ty*NS10/4 + tx;
|
|
@@ -5337,7 +6313,7 @@ void kernel_flash_attn_ext_vec_impl(
|
|
|
5337
6313
|
mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]);
|
|
5338
6314
|
}
|
|
5339
6315
|
} else {
|
|
5340
|
-
device const kd4_t * pk = (device const kd4_t *) (
|
|
6316
|
+
device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11));
|
|
5341
6317
|
|
|
5342
6318
|
k4_t mk;
|
|
5343
6319
|
|
|
@@ -5435,7 +6411,7 @@ void kernel_flash_attn_ext_vec_impl(
|
|
|
5435
6411
|
}
|
|
5436
6412
|
|
|
5437
6413
|
if (is_same<vd4_t, v4_t>::value) {
|
|
5438
|
-
device const v4_t * pv4 = (device const v4_t *) (
|
|
6414
|
+
device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21);
|
|
5439
6415
|
|
|
5440
6416
|
pv4 += ty*NS20/4 + tx;
|
|
5441
6417
|
|
|
@@ -5448,7 +6424,7 @@ void kernel_flash_attn_ext_vec_impl(
|
|
|
5448
6424
|
}
|
|
5449
6425
|
} else {
|
|
5450
6426
|
FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
|
|
5451
|
-
device const vd4_t * pv4 = (device const vd4_t *) (
|
|
6427
|
+
device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21));
|
|
5452
6428
|
|
|
5453
6429
|
FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
|
|
5454
6430
|
const short i = ii*NL + tx;
|
|
@@ -5573,7 +6549,7 @@ void kernel_flash_attn_ext_vec_impl(
|
|
|
5573
6549
|
device float4 * dst4 = (device float4 *) dst;
|
|
5574
6550
|
device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results
|
|
5575
6551
|
|
|
5576
|
-
const float S = NWG == 1 ? 1.0f/ss[0] : 1.0f;
|
|
6552
|
+
const float S = NWG == 1 ? (ss[0] == 0.0f ? 0.0f : 1.0f/ss[0]) : 1.0f;
|
|
5577
6553
|
|
|
5578
6554
|
// interleave the workgroup data
|
|
5579
6555
|
for (short i = tiisg; i < DV4; i += NW) {
|
|
@@ -5611,8 +6587,8 @@ template<
|
|
|
5611
6587
|
short DK, // K head size
|
|
5612
6588
|
short DV, // V head size
|
|
5613
6589
|
short NE = 4, // head elements per thread
|
|
5614
|
-
short Q =
|
|
5615
|
-
short C =
|
|
6590
|
+
short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup
|
|
6591
|
+
short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
|
|
5616
6592
|
kernel void kernel_flash_attn_ext_vec(
|
|
5617
6593
|
constant ggml_metal_kargs_flash_attn_ext_vec & args,
|
|
5618
6594
|
device const char * q,
|
|
@@ -5620,13 +6596,14 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
5620
6596
|
device const char * v,
|
|
5621
6597
|
device const char * mask,
|
|
5622
6598
|
device const char * sinks,
|
|
6599
|
+
device const char * pad,
|
|
5623
6600
|
device char * dst,
|
|
5624
6601
|
threadgroup half * shmem_f16 [[threadgroup(0)]],
|
|
5625
6602
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5626
6603
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
5627
6604
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5628
6605
|
#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
|
|
5629
|
-
#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
|
|
6606
|
+
#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
|
|
5630
6607
|
switch (FC_flash_attn_ext_vec_nsg) {
|
|
5631
6608
|
// note: disabled cases to reduce library load time
|
|
5632
6609
|
case 1: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
|
@@ -5651,79 +6628,106 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
5651
6628
|
float, float4, \
|
|
5652
6629
|
float4
|
|
5653
6630
|
|
|
6631
|
+
#define FA_TYPES_F32 \
|
|
6632
|
+
half4, \
|
|
6633
|
+
float4, \
|
|
6634
|
+
float4, \
|
|
6635
|
+
float, \
|
|
6636
|
+
float, float4, \
|
|
6637
|
+
float4
|
|
6638
|
+
|
|
5654
6639
|
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
|
|
5655
6640
|
|
|
5656
|
-
template [[host_name("
|
|
6641
|
+
template [[host_name("kernel_flash_attn_ext_vec_f32_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 32, 32, 4>;
|
|
6642
|
+
template [[host_name("kernel_flash_attn_ext_vec_f16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 32, 32, 4>;
|
|
5657
6643
|
#if defined(GGML_METAL_HAS_BF16)
|
|
5658
|
-
template [[host_name("
|
|
6644
|
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 32, 32, 4>;
|
|
5659
6645
|
#endif
|
|
5660
|
-
template [[host_name("
|
|
5661
|
-
template [[host_name("
|
|
5662
|
-
template [[host_name("
|
|
5663
|
-
template [[host_name("
|
|
5664
|
-
template [[host_name("
|
|
5665
|
-
|
|
5666
|
-
template [[host_name("
|
|
6646
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 32, 32, 4>;
|
|
6647
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 32, 32, 4>;
|
|
6648
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 32, 32, 4>;
|
|
6649
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 32, 32, 4>;
|
|
6650
|
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 32, 32, 4>;
|
|
6651
|
+
|
|
6652
|
+
template [[host_name("kernel_flash_attn_ext_vec_f32_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 64, 64, 2>;
|
|
6653
|
+
template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>;
|
|
5667
6654
|
#if defined(GGML_METAL_HAS_BF16)
|
|
5668
|
-
template [[host_name("
|
|
6655
|
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>;
|
|
5669
6656
|
#endif
|
|
5670
|
-
template [[host_name("
|
|
5671
|
-
template [[host_name("
|
|
5672
|
-
template [[host_name("
|
|
5673
|
-
template [[host_name("
|
|
5674
|
-
template [[host_name("
|
|
5675
|
-
|
|
5676
|
-
template [[host_name("
|
|
6657
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 2>;
|
|
6658
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 2>;
|
|
6659
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 2>;
|
|
6660
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>;
|
|
6661
|
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>;
|
|
6662
|
+
|
|
6663
|
+
template [[host_name("kernel_flash_attn_ext_vec_f32_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 96, 96, 4>;
|
|
6664
|
+
template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
|
|
5677
6665
|
#if defined(GGML_METAL_HAS_BF16)
|
|
5678
|
-
template [[host_name("
|
|
6666
|
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
|
|
5679
6667
|
#endif
|
|
5680
|
-
template [[host_name("
|
|
5681
|
-
template [[host_name("
|
|
5682
|
-
template [[host_name("
|
|
5683
|
-
template [[host_name("
|
|
5684
|
-
template [[host_name("
|
|
5685
|
-
|
|
5686
|
-
template [[host_name("
|
|
6668
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 96, 96, 4>;
|
|
6669
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 96, 96, 4>;
|
|
6670
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 96, 96, 4>;
|
|
6671
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
|
|
6672
|
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
|
|
6673
|
+
|
|
6674
|
+
template [[host_name("kernel_flash_attn_ext_vec_f32_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 128, 128, 1>;
|
|
6675
|
+
template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>;
|
|
5687
6676
|
#if defined(GGML_METAL_HAS_BF16)
|
|
5688
|
-
template [[host_name("
|
|
6677
|
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>;
|
|
5689
6678
|
#endif
|
|
5690
|
-
template [[host_name("
|
|
5691
|
-
template [[host_name("
|
|
5692
|
-
template [[host_name("
|
|
5693
|
-
template [[host_name("
|
|
5694
|
-
template [[host_name("
|
|
5695
|
-
|
|
5696
|
-
template [[host_name("
|
|
6679
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 1>;
|
|
6680
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 1>;
|
|
6681
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 1>;
|
|
6682
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>;
|
|
6683
|
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>;
|
|
6684
|
+
|
|
6685
|
+
template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 192, 2>;
|
|
6686
|
+
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>;
|
|
5697
6687
|
#if defined(GGML_METAL_HAS_BF16)
|
|
5698
|
-
template [[host_name("
|
|
6688
|
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>;
|
|
5699
6689
|
#endif
|
|
5700
|
-
template [[host_name("
|
|
5701
|
-
template [[host_name("
|
|
5702
|
-
template [[host_name("
|
|
5703
|
-
template [[host_name("
|
|
5704
|
-
template [[host_name("
|
|
5705
|
-
|
|
5706
|
-
template [[host_name("
|
|
6690
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 2>;
|
|
6691
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 2>;
|
|
6692
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 2>;
|
|
6693
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>;
|
|
6694
|
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>;
|
|
6695
|
+
|
|
6696
|
+
template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 128, 2>;
|
|
6697
|
+
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>;
|
|
5707
6698
|
#if defined(GGML_METAL_HAS_BF16)
|
|
5708
|
-
template [[host_name("
|
|
6699
|
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>;
|
|
5709
6700
|
#endif
|
|
5710
|
-
template [[host_name("
|
|
5711
|
-
template [[host_name("
|
|
5712
|
-
template [[host_name("
|
|
5713
|
-
template [[host_name("
|
|
5714
|
-
template [[host_name("
|
|
5715
|
-
|
|
5716
|
-
template [[host_name("
|
|
6701
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 2>;
|
|
6702
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 2>;
|
|
6703
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 2>;
|
|
6704
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>;
|
|
6705
|
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>;
|
|
6706
|
+
|
|
6707
|
+
template [[host_name("kernel_flash_attn_ext_vec_f32_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 256, 256, 1>;
|
|
6708
|
+
template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>;
|
|
5717
6709
|
#if defined(GGML_METAL_HAS_BF16)
|
|
5718
|
-
template [[host_name("
|
|
6710
|
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>;
|
|
5719
6711
|
#endif
|
|
5720
|
-
template [[host_name("
|
|
5721
|
-
template [[host_name("
|
|
5722
|
-
template [[host_name("
|
|
5723
|
-
template [[host_name("
|
|
5724
|
-
template [[host_name("
|
|
6712
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 1>;
|
|
6713
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 1>;
|
|
6714
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 1>;
|
|
6715
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
|
|
6716
|
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
|
|
6717
|
+
|
|
6718
|
+
template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
|
|
6719
|
+
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
|
|
6720
|
+
#if defined(GGML_METAL_HAS_BF16)
|
|
6721
|
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
|
|
6722
|
+
#endif
|
|
6723
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
|
|
6724
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
|
|
6725
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
|
|
6726
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
|
|
6727
|
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
|
|
5725
6728
|
|
|
5726
6729
|
#undef FA_TYPES
|
|
6730
|
+
#undef FA_TYPES_F32
|
|
5727
6731
|
|
|
5728
6732
|
constant int32_t FC_flash_attn_ext_vec_reduce_DV [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 0)]];
|
|
5729
6733
|
constant int32_t FC_flash_attn_ext_vec_reduce_NWG [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 1)]];
|
|
@@ -5750,7 +6754,8 @@ kernel void kernel_flash_attn_ext_vec_reduce(
|
|
|
5750
6754
|
const float m = simd_max(M);
|
|
5751
6755
|
const float ms = exp(M - m);
|
|
5752
6756
|
|
|
5753
|
-
S =
|
|
6757
|
+
S = simd_sum(S*ms);
|
|
6758
|
+
S = S == 0.0f ? 0.0f : 1.0f/S;
|
|
5754
6759
|
|
|
5755
6760
|
const short DV4 = DV/4;
|
|
5756
6761
|
|
|
@@ -5770,21 +6775,17 @@ kernel void kernel_flash_attn_ext_vec_reduce(
|
|
|
5770
6775
|
}
|
|
5771
6776
|
|
|
5772
6777
|
template<typename T0, typename T1>
|
|
5773
|
-
kernel void
|
|
6778
|
+
kernel void kernel_cpy_t_t(
|
|
5774
6779
|
constant ggml_metal_kargs_cpy & args,
|
|
5775
6780
|
device const char * src0,
|
|
5776
6781
|
device char * dst,
|
|
5777
6782
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5778
|
-
|
|
5779
|
-
ushort3
|
|
5780
|
-
ushort3 tptg[[threads_per_threadgroup]]) {
|
|
6783
|
+
ushort tiitg[[thread_index_in_threadgroup]],
|
|
6784
|
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
5781
6785
|
const int i03 = tgpig[2];
|
|
5782
6786
|
const int i02 = tgpig[1];
|
|
5783
|
-
const int i01 = tgpig[0]
|
|
5784
|
-
|
|
5785
|
-
if (i01 >= args.ne01) {
|
|
5786
|
-
return;
|
|
5787
|
-
}
|
|
6787
|
+
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
|
|
6788
|
+
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
|
5788
6789
|
|
|
5789
6790
|
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
|
5790
6791
|
|
|
@@ -5795,190 +6796,71 @@ kernel void kernel_cpy(
|
|
|
5795
6796
|
|
|
5796
6797
|
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
|
5797
6798
|
|
|
5798
|
-
for (int64_t i00 = tiitg%
|
|
6799
|
+
for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) {
|
|
5799
6800
|
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
5800
6801
|
dst_data[i00] = (T1) src[0];
|
|
6802
|
+
break;
|
|
5801
6803
|
}
|
|
5802
6804
|
}
|
|
5803
6805
|
|
|
5804
|
-
typedef decltype(
|
|
6806
|
+
typedef decltype(kernel_cpy_t_t<float, float>) kernel_cpy_t;
|
|
5805
6807
|
|
|
5806
|
-
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t
|
|
5807
|
-
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t
|
|
5808
|
-
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t
|
|
5809
|
-
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t
|
|
6808
|
+
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, float>;
|
|
6809
|
+
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, half>;
|
|
6810
|
+
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, int32_t>;
|
|
6811
|
+
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
|
|
6812
|
+
template [[host_name("kernel_cpy_i32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, int32_t>;
|
|
5810
6813
|
#if defined(GGML_METAL_HAS_BF16)
|
|
5811
|
-
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t
|
|
6814
|
+
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, bfloat>;
|
|
5812
6815
|
#endif
|
|
5813
|
-
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t
|
|
5814
|
-
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t
|
|
6816
|
+
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<half, float>;
|
|
6817
|
+
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<half, half>;
|
|
5815
6818
|
#if defined(GGML_METAL_HAS_BF16)
|
|
5816
|
-
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t
|
|
5817
|
-
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t
|
|
6819
|
+
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, float>;
|
|
6820
|
+
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, bfloat>;
|
|
5818
6821
|
#endif
|
|
5819
6822
|
|
|
5820
|
-
|
|
5821
|
-
|
|
6823
|
+
template<short QK,
|
|
6824
|
+
typename block_q,
|
|
6825
|
+
void (*quantize_func)(device const float *, device block_q &)>
|
|
6826
|
+
kernel void kernel_cpy_f32_q(
|
|
5822
6827
|
constant ggml_metal_kargs_cpy & args,
|
|
5823
6828
|
device const char * src0,
|
|
5824
|
-
device
|
|
6829
|
+
device char * dst,
|
|
5825
6830
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5826
|
-
|
|
5827
|
-
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
5828
|
-
const int i03 = tgpig[2];
|
|
5829
|
-
const int i02 = tgpig[1];
|
|
5830
|
-
const int i01 = tgpig[0];
|
|
5831
|
-
|
|
5832
|
-
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
|
5833
|
-
|
|
5834
|
-
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
|
|
5835
|
-
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
|
|
5836
|
-
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
|
|
5837
|
-
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0;
|
|
5838
|
-
|
|
5839
|
-
device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
|
5840
|
-
|
|
5841
|
-
for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
|
|
5842
|
-
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
5843
|
-
|
|
5844
|
-
quantize_q8_0(src, dst_data[i00/QK8_0]);
|
|
5845
|
-
}
|
|
5846
|
-
}
|
|
5847
|
-
|
|
5848
|
-
kernel void kernel_cpy_f32_q4_0(
|
|
5849
|
-
constant ggml_metal_kargs_cpy & args,
|
|
5850
|
-
device const char * src0,
|
|
5851
|
-
device char * dst,
|
|
5852
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5853
|
-
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
5854
|
-
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
5855
|
-
const int i03 = tgpig[2];
|
|
5856
|
-
const int i02 = tgpig[1];
|
|
5857
|
-
const int i01 = tgpig[0];
|
|
5858
|
-
|
|
5859
|
-
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
|
5860
|
-
|
|
5861
|
-
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
|
|
5862
|
-
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
|
|
5863
|
-
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
|
|
5864
|
-
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0;
|
|
5865
|
-
|
|
5866
|
-
device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
|
5867
|
-
|
|
5868
|
-
for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
|
|
5869
|
-
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
5870
|
-
|
|
5871
|
-
quantize_q4_0(src, dst_data[i00/QK4_0]);
|
|
5872
|
-
}
|
|
5873
|
-
}
|
|
5874
|
-
|
|
5875
|
-
kernel void kernel_cpy_f32_q4_1(
|
|
5876
|
-
constant ggml_metal_kargs_cpy & args,
|
|
5877
|
-
device const char * src0,
|
|
5878
|
-
device char * dst,
|
|
5879
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5880
|
-
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
6831
|
+
ushort tiitg[[thread_index_in_threadgroup]],
|
|
5881
6832
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
5882
6833
|
const int i03 = tgpig[2];
|
|
5883
6834
|
const int i02 = tgpig[1];
|
|
5884
|
-
const int i01 = tgpig[0];
|
|
6835
|
+
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
|
|
6836
|
+
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
|
5885
6837
|
|
|
5886
6838
|
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
|
5887
6839
|
|
|
5888
6840
|
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
|
|
5889
6841
|
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
|
|
5890
6842
|
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
|
|
5891
|
-
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/
|
|
5892
|
-
|
|
5893
|
-
device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
|
5894
|
-
|
|
5895
|
-
for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
|
|
5896
|
-
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
5897
|
-
|
|
5898
|
-
quantize_q4_1(src, dst_data[i00/QK4_1]);
|
|
5899
|
-
}
|
|
5900
|
-
}
|
|
5901
|
-
|
|
5902
|
-
kernel void kernel_cpy_f32_q5_0(
|
|
5903
|
-
constant ggml_metal_kargs_cpy & args,
|
|
5904
|
-
device const char * src0,
|
|
5905
|
-
device char * dst,
|
|
5906
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5907
|
-
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
5908
|
-
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
5909
|
-
const int i03 = tgpig[2];
|
|
5910
|
-
const int i02 = tgpig[1];
|
|
5911
|
-
const int i01 = tgpig[0];
|
|
6843
|
+
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
|
|
5912
6844
|
|
|
5913
|
-
|
|
6845
|
+
device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
|
5914
6846
|
|
|
5915
|
-
|
|
5916
|
-
|
|
5917
|
-
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
|
|
5918
|
-
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0;
|
|
5919
|
-
|
|
5920
|
-
device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
|
6847
|
+
for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
|
|
6848
|
+
device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
|
|
5921
6849
|
|
|
5922
|
-
|
|
5923
|
-
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
6850
|
+
quantize_func(src, dst_data[i00]);
|
|
5924
6851
|
|
|
5925
|
-
|
|
6852
|
+
break;
|
|
5926
6853
|
}
|
|
5927
6854
|
}
|
|
5928
6855
|
|
|
5929
|
-
|
|
5930
|
-
constant ggml_metal_kargs_cpy & args,
|
|
5931
|
-
device const char * src0,
|
|
5932
|
-
device char * dst,
|
|
5933
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5934
|
-
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
5935
|
-
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
5936
|
-
const int i03 = tgpig[2];
|
|
5937
|
-
const int i02 = tgpig[1];
|
|
5938
|
-
const int i01 = tgpig[0];
|
|
5939
|
-
|
|
5940
|
-
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
|
6856
|
+
typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
|
|
5941
6857
|
|
|
5942
|
-
|
|
5943
|
-
|
|
5944
|
-
|
|
5945
|
-
|
|
5946
|
-
|
|
5947
|
-
|
|
5948
|
-
|
|
5949
|
-
for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
|
|
5950
|
-
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
5951
|
-
|
|
5952
|
-
quantize_q5_1(src, dst_data[i00/QK5_1]);
|
|
5953
|
-
}
|
|
5954
|
-
}
|
|
5955
|
-
|
|
5956
|
-
kernel void kernel_cpy_f32_iq4_nl(
|
|
5957
|
-
constant ggml_metal_kargs_cpy & args,
|
|
5958
|
-
device const char * src0,
|
|
5959
|
-
device char * dst,
|
|
5960
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5961
|
-
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
5962
|
-
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
5963
|
-
const int i03 = tgpig[2];
|
|
5964
|
-
const int i02 = tgpig[1];
|
|
5965
|
-
const int i01 = tgpig[0];
|
|
5966
|
-
|
|
5967
|
-
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
|
5968
|
-
|
|
5969
|
-
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
|
|
5970
|
-
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
|
|
5971
|
-
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
|
|
5972
|
-
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL;
|
|
5973
|
-
|
|
5974
|
-
device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
|
5975
|
-
|
|
5976
|
-
for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
|
|
5977
|
-
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
5978
|
-
|
|
5979
|
-
quantize_iq4_nl(src, dst_data[i00/QK4_NL]);
|
|
5980
|
-
}
|
|
5981
|
-
}
|
|
6858
|
+
template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
|
|
6859
|
+
template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
|
|
6860
|
+
template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
|
|
6861
|
+
template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
|
|
6862
|
+
template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_1, block_q5_1, quantize_q5_1>;
|
|
6863
|
+
template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_NL, block_iq4_nl, quantize_iq4_nl>;
|
|
5982
6864
|
|
|
5983
6865
|
template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
|
5984
6866
|
kernel void kernel_cpy_q_f32(
|
|
@@ -5986,11 +6868,12 @@ kernel void kernel_cpy_q_f32(
|
|
|
5986
6868
|
device const char * src0,
|
|
5987
6869
|
device char * dst,
|
|
5988
6870
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5989
|
-
|
|
6871
|
+
ushort tiitg[[thread_index_in_threadgroup]],
|
|
5990
6872
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
5991
6873
|
const int i03 = tgpig[2];
|
|
5992
6874
|
const int i02 = tgpig[1];
|
|
5993
|
-
const int i01 = tgpig[0];
|
|
6875
|
+
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
|
|
6876
|
+
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
|
5994
6877
|
|
|
5995
6878
|
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
|
5996
6879
|
|
|
@@ -6002,10 +6885,12 @@ kernel void kernel_cpy_q_f32(
|
|
|
6002
6885
|
device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
|
6003
6886
|
device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
|
6004
6887
|
|
|
6005
|
-
for (int64_t i00 =
|
|
6888
|
+
for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
|
|
6006
6889
|
T4x4 temp;
|
|
6007
6890
|
dequantize_func(src_data + i00/nl, i00%nl, temp);
|
|
6008
6891
|
dst_data[i00] = temp;
|
|
6892
|
+
|
|
6893
|
+
break;
|
|
6009
6894
|
}
|
|
6010
6895
|
}
|
|
6011
6896
|
|
|
@@ -7458,7 +8343,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
|
|
|
7458
8343
|
kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
|
7459
8344
|
}
|
|
7460
8345
|
|
|
7461
|
-
template<int
|
|
8346
|
+
template<int NR0, typename args_t>
|
|
7462
8347
|
void kernel_mul_mv_iq4_nl_f32_impl(
|
|
7463
8348
|
args_t args,
|
|
7464
8349
|
device const char * src0,
|
|
@@ -7471,13 +8356,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
|
7471
8356
|
const short NSG = FC_mul_mv_nsg;
|
|
7472
8357
|
|
|
7473
8358
|
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
|
7474
|
-
const int nb = args.ne00/QK4_NL;
|
|
7475
8359
|
|
|
7476
8360
|
const int r0 = tgpig.x;
|
|
7477
8361
|
const int r1 = tgpig.y;
|
|
7478
8362
|
const int im = tgpig.z;
|
|
7479
8363
|
|
|
7480
|
-
const int first_row = (r0 * NSG + sgitg) *
|
|
8364
|
+
const int first_row = (r0 * NSG + sgitg) * NR0;
|
|
7481
8365
|
|
|
7482
8366
|
const uint i12 = im%args.ne12;
|
|
7483
8367
|
const uint i13 = im/args.ne12;
|
|
@@ -7488,6 +8372,9 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
|
7488
8372
|
device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
|
|
7489
8373
|
device const float * y = (device const float *) (src1 + offset1);
|
|
7490
8374
|
|
|
8375
|
+
const int nb = args.ne00/QK4_NL;
|
|
8376
|
+
const int ns01 = args.nb01/args.nb00;
|
|
8377
|
+
|
|
7491
8378
|
const short ix = tiisg/2; // 0...15
|
|
7492
8379
|
const short it = tiisg%2; // 0 or 1
|
|
7493
8380
|
|
|
@@ -7495,24 +8382,25 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
|
7495
8382
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
7496
8383
|
|
|
7497
8384
|
float4 yl[4];
|
|
7498
|
-
float sumf[
|
|
8385
|
+
float sumf[NR0]={0.f};
|
|
7499
8386
|
|
|
7500
|
-
device const float * yb = y + ix
|
|
8387
|
+
device const float * yb = y + ix*QK4_NL + it*8;
|
|
7501
8388
|
|
|
7502
8389
|
uint32_t aux32[2];
|
|
7503
8390
|
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
|
|
7504
8391
|
|
|
7505
8392
|
float4 qf1, qf2;
|
|
7506
8393
|
|
|
7507
|
-
|
|
8394
|
+
// [TAG_MUL_MV_WEIRD]
|
|
8395
|
+
for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
|
|
7508
8396
|
device const float4 * y4 = (device const float4 *)yb;
|
|
7509
8397
|
yl[0] = y4[0];
|
|
7510
8398
|
yl[1] = y4[4];
|
|
7511
8399
|
yl[2] = y4[1];
|
|
7512
8400
|
yl[3] = y4[5];
|
|
7513
8401
|
|
|
7514
|
-
for (short row = 0; row <
|
|
7515
|
-
device const block_iq4_nl & xb = x[row*
|
|
8402
|
+
for (short row = 0; row < NR0; row++) {
|
|
8403
|
+
device const block_iq4_nl & xb = x[row*ns01 + ib];
|
|
7516
8404
|
device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
|
|
7517
8405
|
|
|
7518
8406
|
float4 acc1 = {0.f}, acc2 = {0.f};
|
|
@@ -7543,7 +8431,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
|
7543
8431
|
|
|
7544
8432
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
|
7545
8433
|
|
|
7546
|
-
for (int row = 0; row <
|
|
8434
|
+
for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
|
|
7547
8435
|
float sum_all = simd_sum(sumf[row]);
|
|
7548
8436
|
if (tiisg == 0) {
|
|
7549
8437
|
dst_f32[first_row + row] = sum_all;
|
|
@@ -7565,7 +8453,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
|
|
|
7565
8453
|
kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
7566
8454
|
}
|
|
7567
8455
|
|
|
7568
|
-
template<int
|
|
8456
|
+
template<int NR0, typename args_t>
|
|
7569
8457
|
void kernel_mul_mv_iq4_xs_f32_impl(
|
|
7570
8458
|
args_t args,
|
|
7571
8459
|
device const char * src0,
|
|
@@ -7578,12 +8466,11 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
|
7578
8466
|
const short NSG = FC_mul_mv_nsg;
|
|
7579
8467
|
|
|
7580
8468
|
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
|
7581
|
-
const int nb = args.ne00/QK_K;
|
|
7582
8469
|
|
|
7583
8470
|
const int r0 = tgpig.x;
|
|
7584
8471
|
const int r1 = tgpig.y;
|
|
7585
8472
|
const int im = tgpig.z;
|
|
7586
|
-
const int first_row = (r0 * NSG + sgitg) *
|
|
8473
|
+
const int first_row = (r0 * NSG + sgitg) * NR0;
|
|
7587
8474
|
|
|
7588
8475
|
const uint i12 = im%args.ne12;
|
|
7589
8476
|
const uint i13 = im/args.ne12;
|
|
@@ -7594,6 +8481,9 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
|
7594
8481
|
device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
|
|
7595
8482
|
device const float * y = (device const float *) (src1 + offset1);
|
|
7596
8483
|
|
|
8484
|
+
const int nb = args.ne00/QK_K;
|
|
8485
|
+
const int ns01 = args.nb01/args.nb00;
|
|
8486
|
+
|
|
7597
8487
|
const short ix = tiisg/16; // 0 or 1
|
|
7598
8488
|
const short it = tiisg%16; // 0...15
|
|
7599
8489
|
const short ib = it/2;
|
|
@@ -7603,7 +8493,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
|
7603
8493
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
7604
8494
|
|
|
7605
8495
|
float4 yl[4];
|
|
7606
|
-
float sumf[
|
|
8496
|
+
float sumf[NR0]={0.f};
|
|
7607
8497
|
|
|
7608
8498
|
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
|
|
7609
8499
|
|
|
@@ -7612,15 +8502,16 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
|
7612
8502
|
|
|
7613
8503
|
float4 qf1, qf2;
|
|
7614
8504
|
|
|
7615
|
-
|
|
8505
|
+
// [TAG_MUL_MV_WEIRD]
|
|
8506
|
+
for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) {
|
|
7616
8507
|
device const float4 * y4 = (device const float4 *)yb;
|
|
7617
8508
|
yl[0] = y4[0];
|
|
7618
8509
|
yl[1] = y4[4];
|
|
7619
8510
|
yl[2] = y4[1];
|
|
7620
8511
|
yl[3] = y4[5];
|
|
7621
8512
|
|
|
7622
|
-
for (short row = 0; row <
|
|
7623
|
-
device const block_iq4_xs & xb = x[row*
|
|
8513
|
+
for (short row = 0; row < NR0; ++row) {
|
|
8514
|
+
device const block_iq4_xs & xb = x[row*ns01 + ibl];
|
|
7624
8515
|
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
|
|
7625
8516
|
|
|
7626
8517
|
float4 acc1 = {0.f}, acc2 = {0.f};
|
|
@@ -7650,7 +8541,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
|
7650
8541
|
|
|
7651
8542
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
|
7652
8543
|
|
|
7653
|
-
for (int row = 0; row <
|
|
8544
|
+
for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
|
|
7654
8545
|
float sum_all = simd_sum(sumf[row]);
|
|
7655
8546
|
if (tiisg == 0) {
|
|
7656
8547
|
dst_f32[first_row + row] = sum_all;
|
|
@@ -7672,7 +8563,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
|
|
7672
8563
|
kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
7673
8564
|
}
|
|
7674
8565
|
|
|
7675
|
-
template<int
|
|
8566
|
+
template<int NR0, typename args_t>
|
|
7676
8567
|
void kernel_mul_mv_mxfp4_f32_impl(
|
|
7677
8568
|
args_t args,
|
|
7678
8569
|
device const char * src0,
|
|
@@ -7685,13 +8576,12 @@ void kernel_mul_mv_mxfp4_f32_impl(
|
|
|
7685
8576
|
const short NSG = FC_mul_mv_nsg;
|
|
7686
8577
|
|
|
7687
8578
|
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
|
7688
|
-
const int nb = args.ne00/QK_MXFP4;
|
|
7689
8579
|
|
|
7690
8580
|
const int r0 = tgpig.x;
|
|
7691
8581
|
const int r1 = tgpig.y;
|
|
7692
8582
|
const int im = tgpig.z;
|
|
7693
8583
|
|
|
7694
|
-
const int first_row = (r0 * NSG + sgitg) *
|
|
8584
|
+
const int first_row = (r0 * NSG + sgitg) * NR0;
|
|
7695
8585
|
|
|
7696
8586
|
const uint i12 = im%args.ne12;
|
|
7697
8587
|
const uint i13 = im/args.ne12;
|
|
@@ -7702,6 +8592,9 @@ void kernel_mul_mv_mxfp4_f32_impl(
|
|
|
7702
8592
|
device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
|
|
7703
8593
|
device const float * y = (device const float *) (src1 + offset1);
|
|
7704
8594
|
|
|
8595
|
+
const int nb = args.ne00/QK_MXFP4;
|
|
8596
|
+
const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors
|
|
8597
|
+
|
|
7705
8598
|
const short ix = tiisg/2; // 0...15
|
|
7706
8599
|
const short it = tiisg%2; // 0 or 1
|
|
7707
8600
|
|
|
@@ -7709,20 +8602,22 @@ void kernel_mul_mv_mxfp4_f32_impl(
|
|
|
7709
8602
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
7710
8603
|
|
|
7711
8604
|
float4 yl[4];
|
|
7712
|
-
float sumf[
|
|
8605
|
+
float sumf[NR0]={0.f};
|
|
7713
8606
|
|
|
7714
|
-
device const float * yb = y + ix
|
|
8607
|
+
device const float * yb = y + ix*QK_MXFP4 + it*8;
|
|
8608
|
+
|
|
8609
|
+
// note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster
|
|
8610
|
+
// no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD]
|
|
8611
|
+
for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
|
|
8612
|
+
device const float4 * y4 = (device const float4 *) yb;
|
|
7715
8613
|
|
|
7716
|
-
for (int ib = ix; ib < nb; ib += 16) {
|
|
7717
|
-
device const float4 * y4 = (device const float4 *)yb;
|
|
7718
8614
|
yl[0] = y4[0];
|
|
7719
8615
|
yl[1] = y4[4];
|
|
7720
8616
|
yl[2] = y4[1];
|
|
7721
8617
|
yl[3] = y4[5];
|
|
7722
8618
|
|
|
7723
|
-
|
|
7724
|
-
|
|
7725
|
-
device const block_mxfp4 & xb = x[row*nb + ib];
|
|
8619
|
+
FOR_UNROLL (short row = 0; row < NR0; row++) {
|
|
8620
|
+
device const block_mxfp4 & xb = x[row*ns01 + ib];
|
|
7726
8621
|
device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it);
|
|
7727
8622
|
|
|
7728
8623
|
float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
|
|
@@ -7740,7 +8635,7 @@ void kernel_mul_mv_mxfp4_f32_impl(
|
|
|
7740
8635
|
|
|
7741
8636
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
|
7742
8637
|
|
|
7743
|
-
for (int row = 0; row <
|
|
8638
|
+
for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
|
|
7744
8639
|
float sum_all = simd_sum(sumf[row]);
|
|
7745
8640
|
if (tiisg == 0) {
|
|
7746
8641
|
dst_f32[first_row + row] = sum_all;
|
|
@@ -7765,66 +8660,60 @@ kernel void kernel_mul_mv_mxfp4_f32(
|
|
|
7765
8660
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
|
7766
8661
|
kernel void kernel_get_rows_q(
|
|
7767
8662
|
constant ggml_metal_kargs_get_rows & args,
|
|
7768
|
-
device const
|
|
7769
|
-
device const
|
|
7770
|
-
device
|
|
7771
|
-
uint3
|
|
7772
|
-
|
|
7773
|
-
|
|
7774
|
-
const
|
|
7775
|
-
const
|
|
8663
|
+
device const void * src0,
|
|
8664
|
+
device const void * src1,
|
|
8665
|
+
device void * dst,
|
|
8666
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
8667
|
+
ushort tiitg[[thread_index_in_threadgroup]],
|
|
8668
|
+
ushort3 ntg [[threads_per_threadgroup]]) {
|
|
8669
|
+
const int32_t iw0 = tgpig.x/args.ne10;
|
|
8670
|
+
const int32_t i10 = tgpig.x%args.ne10;
|
|
8671
|
+
const int32_t i11 = tgpig.y;
|
|
8672
|
+
const int32_t i12 = tgpig.z;
|
|
8673
|
+
|
|
8674
|
+
const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
|
|
7776
8675
|
|
|
7777
|
-
const
|
|
8676
|
+
const int32_t i02 = i11;
|
|
8677
|
+
const int32_t i03 = i12;
|
|
7778
8678
|
|
|
7779
|
-
const
|
|
8679
|
+
auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
|
|
8680
|
+
auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
|
|
7780
8681
|
|
|
7781
|
-
for (
|
|
8682
|
+
for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
|
|
7782
8683
|
float4x4 temp;
|
|
7783
|
-
dequantize_func(
|
|
7784
|
-
|
|
8684
|
+
dequantize_func(psrc + ind/nl, ind%nl, temp);
|
|
8685
|
+
pdst[ind] = temp;
|
|
8686
|
+
|
|
8687
|
+
break;
|
|
7785
8688
|
}
|
|
7786
8689
|
}
|
|
7787
8690
|
|
|
7788
|
-
template<typename T>
|
|
8691
|
+
template<typename T0, typename T>
|
|
7789
8692
|
kernel void kernel_get_rows_f(
|
|
7790
8693
|
constant ggml_metal_kargs_get_rows & args,
|
|
7791
|
-
device const
|
|
7792
|
-
device const
|
|
7793
|
-
device
|
|
7794
|
-
uint3
|
|
7795
|
-
|
|
7796
|
-
|
|
7797
|
-
const
|
|
7798
|
-
const
|
|
7799
|
-
|
|
7800
|
-
const
|
|
7801
|
-
|
|
7802
|
-
const int64_t i02 = i11;
|
|
8694
|
+
device const void * src0,
|
|
8695
|
+
device const void * src1,
|
|
8696
|
+
device void * dst,
|
|
8697
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
8698
|
+
ushort tiitg[[thread_index_in_threadgroup]],
|
|
8699
|
+
ushort3 ntg [[threads_per_threadgroup]]) {
|
|
8700
|
+
const int32_t iw0 = tgpig.x/args.ne10;
|
|
8701
|
+
const int32_t i10 = tgpig.x%args.ne10;
|
|
8702
|
+
const int32_t i11 = tgpig.y;
|
|
8703
|
+
const int32_t i12 = tgpig.z;
|
|
7803
8704
|
|
|
7804
|
-
|
|
7805
|
-
(( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
|
|
7806
|
-
((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
|
|
7807
|
-
}
|
|
7808
|
-
}
|
|
8705
|
+
const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
|
|
7809
8706
|
|
|
7810
|
-
|
|
7811
|
-
|
|
7812
|
-
device const void * src0,
|
|
7813
|
-
device const void * src1,
|
|
7814
|
-
device int32_t * dst,
|
|
7815
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
7816
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
|
7817
|
-
uint3 tptg [[threads_per_threadgroup]]) {
|
|
7818
|
-
const int64_t i10 = tgpig.x;
|
|
7819
|
-
const int64_t i11 = tgpig.y;
|
|
8707
|
+
const int32_t i02 = i11;
|
|
8708
|
+
const int32_t i03 = i12;
|
|
7820
8709
|
|
|
7821
|
-
|
|
8710
|
+
auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
|
|
8711
|
+
auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
|
|
7822
8712
|
|
|
7823
|
-
|
|
8713
|
+
for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
|
|
8714
|
+
pdst[ind] = psrc[ind];
|
|
7824
8715
|
|
|
7825
|
-
|
|
7826
|
-
(( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
|
|
7827
|
-
((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
|
|
8716
|
+
break;
|
|
7828
8717
|
}
|
|
7829
8718
|
}
|
|
7830
8719
|
|
|
@@ -7893,17 +8782,6 @@ kernel void kernel_set_rows_f(
|
|
|
7893
8782
|
constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
|
|
7894
8783
|
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
|
|
7895
8784
|
|
|
7896
|
-
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
|
7897
|
-
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
|
7898
|
-
#define BLOCK_SIZE_K 32
|
|
7899
|
-
#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
|
|
7900
|
-
#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
|
|
7901
|
-
#define THREAD_PER_BLOCK 128
|
|
7902
|
-
#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
|
|
7903
|
-
#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
|
|
7904
|
-
#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
|
|
7905
|
-
#define SG_MAT_ROW 8
|
|
7906
|
-
|
|
7907
8785
|
// each block_q contains 16*nl weights
|
|
7908
8786
|
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
|
7909
8787
|
kernel void kernel_mul_mm(
|
|
@@ -7919,18 +8797,48 @@ kernel void kernel_mul_mm(
|
|
|
7919
8797
|
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
|
7920
8798
|
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
|
7921
8799
|
|
|
7922
|
-
|
|
7923
|
-
|
|
8800
|
+
threadgroup float * sc = (threadgroup float *)(shmem);
|
|
8801
|
+
|
|
8802
|
+
constexpr int NR0 = 64;
|
|
8803
|
+
constexpr int NR1 = 32;
|
|
8804
|
+
|
|
8805
|
+
constexpr int NK = 32;
|
|
8806
|
+
constexpr int NL0 = NK/16;
|
|
8807
|
+
constexpr int NL1 = NK/8;
|
|
8808
|
+
|
|
7924
8809
|
const int im = tgpig.z;
|
|
8810
|
+
const int r0 = tgpig.y*NR0;
|
|
8811
|
+
const int r1 = tgpig.x*NR1;
|
|
7925
8812
|
|
|
7926
8813
|
// if this block is of 64x32 shape or smaller
|
|
7927
|
-
const short
|
|
7928
|
-
const short
|
|
8814
|
+
const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
|
|
8815
|
+
const short nr1 = (args.ne1 - r1 < NR1) ? (args.ne1 - r1) : NR1;
|
|
7929
8816
|
|
|
7930
8817
|
// a thread shouldn't load data outside of the matrix
|
|
7931
|
-
const short
|
|
7932
|
-
const short
|
|
8818
|
+
const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
|
|
8819
|
+
const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
|
|
8820
|
+
|
|
8821
|
+
const short il0 = (tiitg % NL0);
|
|
8822
|
+
|
|
8823
|
+
short il = il0;
|
|
8824
|
+
|
|
8825
|
+
const int i12 = im%args.ne12;
|
|
8826
|
+
const int i13 = im/args.ne12;
|
|
8827
|
+
|
|
8828
|
+
const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
|
8829
|
+
const short offset1 = il0/nl;
|
|
8830
|
+
|
|
8831
|
+
device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
|
|
7933
8832
|
|
|
8833
|
+
const short iy = 8*(tiitg % NL1);
|
|
8834
|
+
|
|
8835
|
+
device const T1 * y = (device const T1 *)(src1
|
|
8836
|
+
+ args.nb13*i13
|
|
8837
|
+
+ args.nb12*i12
|
|
8838
|
+
+ args.nb11*(r1 + lr1)
|
|
8839
|
+
+ args.nb10*iy);
|
|
8840
|
+
|
|
8841
|
+
#ifndef GGML_METAL_HAS_TENSOR
|
|
7934
8842
|
S0_8x8 ma[4];
|
|
7935
8843
|
S1_8x8 mb[2];
|
|
7936
8844
|
|
|
@@ -7939,36 +8847,104 @@ kernel void kernel_mul_mm(
|
|
|
7939
8847
|
for (short i = 0; i < 8; i++){
|
|
7940
8848
|
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
|
7941
8849
|
}
|
|
8850
|
+
#else
|
|
8851
|
+
auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
|
|
8852
|
+
auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
|
|
7942
8853
|
|
|
7943
|
-
|
|
8854
|
+
mpp::tensor_ops::matmul2d<
|
|
8855
|
+
mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
|
|
8856
|
+
execution_simdgroups<4>> mm;
|
|
7944
8857
|
|
|
7945
|
-
|
|
7946
|
-
|
|
8858
|
+
auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
|
|
8859
|
+
#endif
|
|
7947
8860
|
|
|
7948
|
-
|
|
7949
|
-
|
|
8861
|
+
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
|
|
8862
|
+
#ifndef GGML_METAL_HAS_TENSOR
|
|
8863
|
+
// load data and store to threadgroup memory
|
|
8864
|
+
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
|
8865
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
7950
8866
|
|
|
7951
|
-
|
|
7952
|
-
|
|
8867
|
+
// no need for dequantization
|
|
8868
|
+
for (short i = 0; i < 16; i++) {
|
|
8869
|
+
const short sx = 2*il0 + i/8;
|
|
8870
|
+
const short sy = (tiitg/NL0)/8;
|
|
7953
8871
|
|
|
7954
|
-
|
|
8872
|
+
//const short lx = i%8;
|
|
8873
|
+
//const short ly = (tiitg/NL0)%8;
|
|
8874
|
+
const short lx = (tiitg/NL0)%8;
|
|
8875
|
+
const short ly = i%8;
|
|
7955
8876
|
|
|
7956
|
-
|
|
7957
|
-
|
|
7958
|
-
|
|
7959
|
-
|
|
7960
|
-
|
|
8877
|
+
const short ib = 8*sx + sy;
|
|
8878
|
+
|
|
8879
|
+
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
|
|
8880
|
+
}
|
|
8881
|
+
} else {
|
|
8882
|
+
S0_4x4 temp_a;
|
|
8883
|
+
dequantize_func(x, il, temp_a);
|
|
8884
|
+
|
|
8885
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
8886
|
+
|
|
8887
|
+
FOR_UNROLL (short i = 0; i < 16; i++) {
|
|
8888
|
+
const short sx = 2*il0 + i/8;
|
|
8889
|
+
const short sy = (tiitg/NL0)/8;
|
|
8890
|
+
|
|
8891
|
+
//const short lx = i%8;
|
|
8892
|
+
//const short ly = (tiitg/NL0)%8;
|
|
8893
|
+
const short lx = (tiitg/NL0)%8;
|
|
8894
|
+
const short ly = i%8;
|
|
8895
|
+
|
|
8896
|
+
const short ib = 8*sx + sy;
|
|
8897
|
+
|
|
8898
|
+
// NOTE: this is massively slower.. WTF?
|
|
8899
|
+
//sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
|
|
8900
|
+
|
|
8901
|
+
*(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
|
|
8902
|
+
}
|
|
8903
|
+
}
|
|
8904
|
+
|
|
8905
|
+
if (FC_mul_mm_bc_inp) {
|
|
8906
|
+
for (short i = 0; i < 8; ++i) {
|
|
8907
|
+
const short sx = (tiitg%NL1);
|
|
8908
|
+
const short sy = (tiitg/NL1)/8;
|
|
8909
|
+
|
|
8910
|
+
const short lx = i;
|
|
8911
|
+
const short ly = (tiitg/NL1)%8;
|
|
8912
|
+
//const short lx = (tiitg/NL1)%8;
|
|
8913
|
+
//const short ly = i;
|
|
8914
|
+
|
|
8915
|
+
const short ib = 4*sx + sy;
|
|
8916
|
+
|
|
8917
|
+
*(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
|
|
8918
|
+
}
|
|
8919
|
+
} else {
|
|
8920
|
+
const short sx = (tiitg%NL1);
|
|
8921
|
+
const short sy = (tiitg/NL1)/8;
|
|
8922
|
+
|
|
8923
|
+
const short dx = sx;
|
|
8924
|
+
const short dy = sy;
|
|
7961
8925
|
|
|
7962
|
-
|
|
8926
|
+
const short ly = (tiitg/NL1)%8;
|
|
8927
|
+
|
|
8928
|
+
const short ib = 4*sx + sy;
|
|
8929
|
+
|
|
8930
|
+
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
|
|
8931
|
+
}
|
|
8932
|
+
#else
|
|
7963
8933
|
// load data and store to threadgroup memory
|
|
7964
8934
|
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
|
7965
8935
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
7966
8936
|
|
|
7967
8937
|
// no need for dequantization
|
|
7968
8938
|
for (short i = 0; i < 16; i++) {
|
|
7969
|
-
|
|
7970
|
-
|
|
7971
|
-
|
|
8939
|
+
const short sx = 2*il0 + i/8;
|
|
8940
|
+
const short sy = (tiitg/NL0)/8;
|
|
8941
|
+
|
|
8942
|
+
const short lx = i%8;
|
|
8943
|
+
const short ly = (tiitg/NL0)%8;
|
|
8944
|
+
//const short lx = (tiitg/NL0)%8;
|
|
8945
|
+
//const short ly = i%8;
|
|
8946
|
+
|
|
8947
|
+
*(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
|
|
7972
8948
|
}
|
|
7973
8949
|
} else {
|
|
7974
8950
|
S0_4x4 temp_a;
|
|
@@ -7977,91 +8953,135 @@ kernel void kernel_mul_mm(
|
|
|
7977
8953
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
7978
8954
|
|
|
7979
8955
|
FOR_UNROLL (short i = 0; i < 16; i++) {
|
|
7980
|
-
|
|
7981
|
-
|
|
7982
|
-
|
|
8956
|
+
const short sx = 2*il0 + i/8;
|
|
8957
|
+
const short sy = (tiitg/NL0)/8;
|
|
8958
|
+
|
|
8959
|
+
const short lx = i%8;
|
|
8960
|
+
const short ly = (tiitg/NL0)%8;
|
|
8961
|
+
//const short lx = (tiitg/NL0)%8;
|
|
8962
|
+
//const short ly = i%8;
|
|
8963
|
+
|
|
8964
|
+
*(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
|
|
7983
8965
|
}
|
|
7984
8966
|
}
|
|
7985
8967
|
|
|
7986
8968
|
if (FC_mul_mm_bc_inp) {
|
|
7987
8969
|
for (short i = 0; i < 8; ++i) {
|
|
7988
|
-
|
|
8970
|
+
const short sx = (tiitg%NL1);
|
|
8971
|
+
const short sy = (tiitg/NL1)/8;
|
|
8972
|
+
|
|
8973
|
+
const short lx = i;
|
|
8974
|
+
const short ly = (tiitg/NL1)%8;
|
|
8975
|
+
//const short lx = (tiitg/NL1)%8;
|
|
8976
|
+
//const short ly = i;
|
|
8977
|
+
|
|
8978
|
+
*(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
|
|
7989
8979
|
}
|
|
7990
8980
|
} else {
|
|
7991
|
-
|
|
8981
|
+
const short sx = (tiitg%NL1);
|
|
8982
|
+
const short sy = (tiitg/NL1)/8;
|
|
8983
|
+
|
|
8984
|
+
//const short lx = i;
|
|
8985
|
+
const short ly = (tiitg/NL1)%8;
|
|
8986
|
+
//const short lx = (tiitg/NL1)%8;
|
|
8987
|
+
//const short ly = i;
|
|
8988
|
+
|
|
8989
|
+
*(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
|
|
7992
8990
|
}
|
|
8991
|
+
#endif
|
|
7993
8992
|
|
|
7994
8993
|
il = (il + 2 < nl) ? il + 2 : il % 2;
|
|
7995
8994
|
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
|
7996
|
-
|
|
8995
|
+
|
|
8996
|
+
y += NK;
|
|
7997
8997
|
|
|
7998
8998
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
7999
8999
|
|
|
9000
|
+
#ifndef GGML_METAL_HAS_TENSOR
|
|
8000
9001
|
// load matrices from threadgroup memory and conduct outer products
|
|
8001
|
-
threadgroup const S0 * lsma = (sa +
|
|
8002
|
-
threadgroup const S1 * lsmb = (sb +
|
|
9002
|
+
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
|
|
9003
|
+
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
|
|
8003
9004
|
|
|
8004
|
-
|
|
8005
|
-
for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
|
|
9005
|
+
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
|
|
8006
9006
|
simdgroup_barrier(mem_flags::mem_none);
|
|
8007
9007
|
|
|
8008
|
-
|
|
8009
|
-
|
|
8010
|
-
simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
|
|
9008
|
+
FOR_UNROLL (short i = 0; i < 4; i++) {
|
|
9009
|
+
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
|
|
8011
9010
|
}
|
|
8012
9011
|
|
|
8013
|
-
|
|
8014
|
-
|
|
8015
|
-
|
|
9012
|
+
simdgroup_barrier(mem_flags::mem_none);
|
|
9013
|
+
|
|
9014
|
+
FOR_UNROLL (short i = 0; i < 2; i++) {
|
|
9015
|
+
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
|
|
8016
9016
|
}
|
|
8017
9017
|
|
|
8018
9018
|
simdgroup_barrier(mem_flags::mem_none);
|
|
8019
9019
|
|
|
8020
|
-
|
|
8021
|
-
for (short i = 0; i < 8; i++){
|
|
9020
|
+
FOR_UNROLL (short i = 0; i < 8; i++){
|
|
8022
9021
|
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
|
|
8023
9022
|
}
|
|
8024
9023
|
|
|
8025
|
-
lsma +=
|
|
8026
|
-
lsmb +=
|
|
9024
|
+
lsma += 8*64;
|
|
9025
|
+
lsmb += 4*64;
|
|
8027
9026
|
}
|
|
9027
|
+
#else
|
|
9028
|
+
auto sA = tA.slice(0, 0);
|
|
9029
|
+
auto sB = tB.slice(0, 0);
|
|
9030
|
+
|
|
9031
|
+
mm.run(sB, sA, cT);
|
|
9032
|
+
#endif
|
|
8028
9033
|
}
|
|
8029
9034
|
|
|
8030
|
-
if (!FC_mul_mm_bc_out || (
|
|
9035
|
+
if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
|
|
8031
9036
|
// if no bounds checks on the output are needed, we can directly write to device memory
|
|
9037
|
+
#ifdef GGML_METAL_HAS_TENSOR
|
|
8032
9038
|
device float * C = (device float *) dst +
|
|
8033
|
-
|
|
8034
|
-
|
|
9039
|
+
r0 + \
|
|
9040
|
+
r1 * args.ne0 + im*args.ne1*args.ne0;
|
|
9041
|
+
|
|
9042
|
+
auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(args.ne0, NR1));
|
|
9043
|
+
cT.store(tC);
|
|
9044
|
+
#else
|
|
9045
|
+
device float * C = (device float *) dst +
|
|
9046
|
+
(r0 + 32*(sgitg & 1)) + \
|
|
9047
|
+
(r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
|
|
8035
9048
|
|
|
8036
9049
|
for (short i = 0; i < 8; i++) {
|
|
8037
|
-
simdgroup_store(mc[i], C + 8
|
|
9050
|
+
simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);
|
|
8038
9051
|
}
|
|
9052
|
+
#endif
|
|
8039
9053
|
} else {
|
|
8040
9054
|
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
|
8041
9055
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
8042
|
-
|
|
8043
|
-
|
|
9056
|
+
|
|
9057
|
+
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
|
|
9058
|
+
|
|
9059
|
+
#ifdef GGML_METAL_HAS_TENSOR
|
|
9060
|
+
auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
|
|
9061
|
+
cT.store(tC);
|
|
9062
|
+
#else
|
|
8044
9063
|
for (short i = 0; i < 8; i++) {
|
|
8045
|
-
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*
|
|
9064
|
+
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
|
|
8046
9065
|
}
|
|
9066
|
+
#endif
|
|
8047
9067
|
|
|
8048
9068
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
8049
9069
|
|
|
8050
9070
|
if (sgitg == 0) {
|
|
8051
|
-
for (int j = tiitg; j <
|
|
8052
|
-
device float * D = (device float *) dst +
|
|
9071
|
+
for (int j = tiitg; j < nr1; j += NR1) {
|
|
9072
|
+
device float * D = (device float *) dst + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0;
|
|
8053
9073
|
device float4 * D4 = (device float4 *) D;
|
|
8054
9074
|
|
|
8055
|
-
threadgroup float * C = temp_str + (j*
|
|
9075
|
+
threadgroup float * C = temp_str + (j*NR0);
|
|
8056
9076
|
threadgroup float4 * C4 = (threadgroup float4 *) C;
|
|
8057
9077
|
|
|
8058
9078
|
int i = 0;
|
|
8059
|
-
for (; i <
|
|
9079
|
+
for (; i < nr0/4; i++) {
|
|
8060
9080
|
*(D4 + i) = *(C4 + i);
|
|
8061
9081
|
}
|
|
8062
9082
|
|
|
8063
9083
|
i *= 4;
|
|
8064
|
-
for (; i <
|
|
9084
|
+
for (; i < nr0; i++) {
|
|
8065
9085
|
*(D + i) = *(C + i);
|
|
8066
9086
|
}
|
|
8067
9087
|
}
|
|
@@ -8128,6 +9148,7 @@ typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
|
|
|
8128
9148
|
template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
|
|
8129
9149
|
template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
|
|
8130
9150
|
template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
|
|
9151
|
+
template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>;
|
|
8131
9152
|
template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
|
|
8132
9153
|
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
|
|
8133
9154
|
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
|
|
@@ -8146,55 +9167,55 @@ kernel void kernel_mul_mm_id(
|
|
|
8146
9167
|
ushort tiitg[[thread_index_in_threadgroup]],
|
|
8147
9168
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
8148
9169
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
8149
|
-
|
|
8150
9170
|
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
|
8151
9171
|
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
|
8152
9172
|
|
|
8153
|
-
|
|
8154
|
-
|
|
9173
|
+
threadgroup float * sc = (threadgroup float *)(shmem);
|
|
9174
|
+
|
|
9175
|
+
constexpr int NR0 = 64;
|
|
9176
|
+
constexpr int NR1 = 32;
|
|
9177
|
+
|
|
9178
|
+
constexpr int NK = 32;
|
|
9179
|
+
constexpr int NL0 = NK/16;
|
|
9180
|
+
constexpr int NL1 = NK/8;
|
|
9181
|
+
|
|
8155
9182
|
const int im = tgpig.z; // expert
|
|
9183
|
+
const int r0 = tgpig.y*NR0;
|
|
9184
|
+
const int r1 = tgpig.x*NR1;
|
|
8156
9185
|
|
|
8157
9186
|
device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
|
|
8158
9187
|
device const int32_t * ids_i32 = (device const int32_t *) (hids);
|
|
8159
9188
|
|
|
8160
9189
|
const int32_t neh1 = tpe_u32[im];
|
|
8161
9190
|
|
|
8162
|
-
if (r1
|
|
9191
|
+
if (r1 >= neh1) {
|
|
8163
9192
|
return;
|
|
8164
9193
|
}
|
|
8165
9194
|
|
|
8166
9195
|
// if this block is of 64x32 shape or smaller
|
|
8167
|
-
const short
|
|
8168
|
-
const short
|
|
9196
|
+
const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
|
|
9197
|
+
const short nr1 = ( neh1 - r1 < NR1) ? ( neh1 - r1) : NR1;
|
|
8169
9198
|
|
|
8170
9199
|
// a thread shouldn't load data outside of the matrix
|
|
8171
|
-
const short
|
|
8172
|
-
const short
|
|
8173
|
-
|
|
8174
|
-
S0_8x8 ma[4];
|
|
8175
|
-
S1_8x8 mb[2];
|
|
9200
|
+
const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
|
|
9201
|
+
const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
|
|
8176
9202
|
|
|
8177
|
-
|
|
9203
|
+
const short il0 = (tiitg % NL0);
|
|
8178
9204
|
|
|
8179
|
-
|
|
8180
|
-
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
|
8181
|
-
}
|
|
9205
|
+
short il = il0;
|
|
8182
9206
|
|
|
8183
|
-
|
|
8184
|
-
|
|
8185
|
-
const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col];
|
|
9207
|
+
const int id = ids_i32[im*args.ne21 + r1 + lr1];
|
|
8186
9208
|
|
|
8187
9209
|
const short i11 = (id % args.ne20) % args.ne11;
|
|
8188
9210
|
const short i12 = (id / args.ne20);
|
|
8189
9211
|
const short i13 = 0;
|
|
8190
9212
|
|
|
8191
9213
|
const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
|
|
8192
|
-
const short offset1 =
|
|
9214
|
+
const short offset1 = il0/nl;
|
|
8193
9215
|
|
|
8194
|
-
device const block_q * x = (device const block_q *)(src0
|
|
8195
|
-
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
|
|
9216
|
+
device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
|
|
8196
9217
|
|
|
8197
|
-
const short iy =
|
|
9218
|
+
const short iy = 8*(tiitg % NL1);
|
|
8198
9219
|
|
|
8199
9220
|
device const T1 * y = (device const T1 *)(src1
|
|
8200
9221
|
+ args.nb13*i13
|
|
@@ -8202,16 +9223,45 @@ kernel void kernel_mul_mm_id(
|
|
|
8202
9223
|
+ args.nb11*i11
|
|
8203
9224
|
+ args.nb10*iy);
|
|
8204
9225
|
|
|
8205
|
-
|
|
9226
|
+
#ifndef GGML_METAL_HAS_TENSOR
|
|
9227
|
+
S0_8x8 ma[4];
|
|
9228
|
+
S1_8x8 mb[2];
|
|
9229
|
+
|
|
9230
|
+
simdgroup_float8x8 mc[8];
|
|
9231
|
+
|
|
9232
|
+
for (short i = 0; i < 8; i++){
|
|
9233
|
+
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
|
9234
|
+
}
|
|
9235
|
+
#else
|
|
9236
|
+
auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
|
|
9237
|
+
auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
|
|
9238
|
+
|
|
9239
|
+
mpp::tensor_ops::matmul2d<
|
|
9240
|
+
mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
|
|
9241
|
+
execution_simdgroups<4>> mm;
|
|
9242
|
+
|
|
9243
|
+
auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
|
|
9244
|
+
#endif
|
|
9245
|
+
|
|
9246
|
+
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
|
|
9247
|
+
#ifndef GGML_METAL_HAS_TENSOR
|
|
8206
9248
|
// load data and store to threadgroup memory
|
|
8207
9249
|
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
|
8208
9250
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
8209
9251
|
|
|
8210
9252
|
// no need for dequantization
|
|
8211
9253
|
for (short i = 0; i < 16; i++) {
|
|
8212
|
-
|
|
8213
|
-
|
|
8214
|
-
|
|
9254
|
+
const short sx = 2*il0 + i/8;
|
|
9255
|
+
const short sy = (tiitg/NL0)/8;
|
|
9256
|
+
|
|
9257
|
+
//const short lx = i%8;
|
|
9258
|
+
//const short ly = (tiitg/NL0)%8;
|
|
9259
|
+
const short lx = (tiitg/NL0)%8;
|
|
9260
|
+
const short ly = i%8;
|
|
9261
|
+
|
|
9262
|
+
const short ib = 8*sx + sy;
|
|
9263
|
+
|
|
9264
|
+
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
|
|
8215
9265
|
}
|
|
8216
9266
|
} else {
|
|
8217
9267
|
S0_4x4 temp_a;
|
|
@@ -8220,85 +9270,188 @@ kernel void kernel_mul_mm_id(
|
|
|
8220
9270
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
8221
9271
|
|
|
8222
9272
|
FOR_UNROLL (short i = 0; i < 16; i++) {
|
|
8223
|
-
|
|
8224
|
-
|
|
8225
|
-
|
|
9273
|
+
const short sx = 2*il0 + i/8;
|
|
9274
|
+
const short sy = (tiitg/NL0)/8;
|
|
9275
|
+
|
|
9276
|
+
//const short lx = i%8;
|
|
9277
|
+
//const short ly = (tiitg/NL0)%8;
|
|
9278
|
+
const short lx = (tiitg/NL0)%8;
|
|
9279
|
+
const short ly = i%8;
|
|
9280
|
+
|
|
9281
|
+
const short ib = 8*sx + sy;
|
|
9282
|
+
|
|
9283
|
+
// NOTE: this is massively slower.. WTF?
|
|
9284
|
+
//sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
|
|
9285
|
+
|
|
9286
|
+
*(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
|
|
8226
9287
|
}
|
|
8227
9288
|
}
|
|
8228
9289
|
|
|
8229
9290
|
if (FC_mul_mm_bc_inp) {
|
|
8230
9291
|
for (short i = 0; i < 8; ++i) {
|
|
8231
|
-
|
|
9292
|
+
const short sx = (tiitg%NL1);
|
|
9293
|
+
const short sy = (tiitg/NL1)/8;
|
|
9294
|
+
|
|
9295
|
+
const short lx = i;
|
|
9296
|
+
const short ly = (tiitg/NL1)%8;
|
|
9297
|
+
//const short lx = (tiitg/NL1)%8;
|
|
9298
|
+
//const short ly = i;
|
|
9299
|
+
|
|
9300
|
+
const short ib = 4*sx + sy;
|
|
9301
|
+
|
|
9302
|
+
*(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
|
|
9303
|
+
}
|
|
9304
|
+
} else {
|
|
9305
|
+
const short sx = (tiitg%NL1);
|
|
9306
|
+
const short sy = (tiitg/NL1)/8;
|
|
9307
|
+
|
|
9308
|
+
const short dx = sx;
|
|
9309
|
+
const short dy = sy;
|
|
9310
|
+
|
|
9311
|
+
const short ly = (tiitg/NL1)%8;
|
|
9312
|
+
|
|
9313
|
+
const short ib = 4*sx + sy;
|
|
9314
|
+
|
|
9315
|
+
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
|
|
9316
|
+
}
|
|
9317
|
+
#else
|
|
9318
|
+
// load data and store to threadgroup memory
|
|
9319
|
+
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
|
9320
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
9321
|
+
|
|
9322
|
+
// no need for dequantization
|
|
9323
|
+
for (short i = 0; i < 16; i++) {
|
|
9324
|
+
const short sx = 2*il0 + i/8;
|
|
9325
|
+
const short sy = (tiitg/NL0)/8;
|
|
9326
|
+
|
|
9327
|
+
const short lx = i%8;
|
|
9328
|
+
const short ly = (tiitg/NL0)%8;
|
|
9329
|
+
//const short lx = (tiitg/NL0)%8;
|
|
9330
|
+
//const short ly = i%8;
|
|
9331
|
+
|
|
9332
|
+
*(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
|
|
8232
9333
|
}
|
|
8233
9334
|
} else {
|
|
8234
|
-
|
|
9335
|
+
S0_4x4 temp_a;
|
|
9336
|
+
dequantize_func(x, il, temp_a);
|
|
9337
|
+
|
|
9338
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
9339
|
+
|
|
9340
|
+
FOR_UNROLL (short i = 0; i < 16; i++) {
|
|
9341
|
+
const short sx = 2*il0 + i/8;
|
|
9342
|
+
const short sy = (tiitg/NL0)/8;
|
|
9343
|
+
|
|
9344
|
+
const short lx = i%8;
|
|
9345
|
+
const short ly = (tiitg/NL0)%8;
|
|
9346
|
+
//const short lx = (tiitg/NL0)%8;
|
|
9347
|
+
//const short ly = i%8;
|
|
9348
|
+
|
|
9349
|
+
*(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
|
|
9350
|
+
}
|
|
8235
9351
|
}
|
|
8236
9352
|
|
|
9353
|
+
if (FC_mul_mm_bc_inp) {
|
|
9354
|
+
for (short i = 0; i < 8; ++i) {
|
|
9355
|
+
const short sx = (tiitg%NL1);
|
|
9356
|
+
const short sy = (tiitg/NL1)/8;
|
|
9357
|
+
|
|
9358
|
+
const short lx = i;
|
|
9359
|
+
const short ly = (tiitg/NL1)%8;
|
|
9360
|
+
//const short lx = (tiitg/NL1)%8;
|
|
9361
|
+
//const short ly = i;
|
|
9362
|
+
|
|
9363
|
+
*(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
|
|
9364
|
+
}
|
|
9365
|
+
} else {
|
|
9366
|
+
const short sx = (tiitg%NL1);
|
|
9367
|
+
const short sy = (tiitg/NL1)/8;
|
|
9368
|
+
|
|
9369
|
+
//const short lx = i;
|
|
9370
|
+
const short ly = (tiitg/NL1)%8;
|
|
9371
|
+
//const short lx = (tiitg/NL1)%8;
|
|
9372
|
+
//const short ly = i;
|
|
9373
|
+
|
|
9374
|
+
*(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
|
|
9375
|
+
}
|
|
9376
|
+
#endif
|
|
9377
|
+
|
|
8237
9378
|
il = (il + 2 < nl) ? il + 2 : il % 2;
|
|
8238
9379
|
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
|
8239
|
-
|
|
9380
|
+
|
|
9381
|
+
y += NK;
|
|
8240
9382
|
|
|
8241
9383
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
8242
9384
|
|
|
9385
|
+
#ifndef GGML_METAL_HAS_TENSOR
|
|
8243
9386
|
// load matrices from threadgroup memory and conduct outer products
|
|
8244
|
-
threadgroup const S0 * lsma = (sa +
|
|
8245
|
-
threadgroup const S1 * lsmb = (sb +
|
|
8246
|
-
|
|
8247
|
-
|
|
8248
|
-
|
|
8249
|
-
|
|
8250
|
-
|
|
8251
|
-
simdgroup_load(ma[i], lsma +
|
|
9387
|
+
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
|
|
9388
|
+
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
|
|
9389
|
+
|
|
9390
|
+
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
|
|
9391
|
+
simdgroup_barrier(mem_flags::mem_none);
|
|
9392
|
+
|
|
9393
|
+
FOR_UNROLL (short i = 0; i < 4; i++) {
|
|
9394
|
+
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
|
|
8252
9395
|
}
|
|
8253
9396
|
|
|
8254
9397
|
simdgroup_barrier(mem_flags::mem_none);
|
|
8255
9398
|
|
|
8256
|
-
|
|
8257
|
-
|
|
8258
|
-
simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
|
|
9399
|
+
FOR_UNROLL (short i = 0; i < 2; i++) {
|
|
9400
|
+
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
|
|
8259
9401
|
}
|
|
8260
9402
|
|
|
8261
|
-
|
|
8262
|
-
|
|
9403
|
+
simdgroup_barrier(mem_flags::mem_none);
|
|
9404
|
+
|
|
9405
|
+
FOR_UNROLL (short i = 0; i < 8; i++){
|
|
8263
9406
|
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
|
|
8264
9407
|
}
|
|
8265
9408
|
|
|
8266
|
-
lsma +=
|
|
8267
|
-
lsmb +=
|
|
9409
|
+
lsma += 8*64;
|
|
9410
|
+
lsmb += 4*64;
|
|
8268
9411
|
}
|
|
9412
|
+
#else
|
|
9413
|
+
auto sA = tA.slice(0, 0);
|
|
9414
|
+
auto sB = tB.slice(0, 0);
|
|
9415
|
+
|
|
9416
|
+
mm.run(sB, sA, cT);
|
|
9417
|
+
#endif
|
|
8269
9418
|
}
|
|
8270
9419
|
|
|
9420
|
+
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
|
8271
9421
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
8272
9422
|
|
|
8273
|
-
|
|
8274
|
-
|
|
9423
|
+
#ifdef GGML_METAL_HAS_TENSOR
|
|
9424
|
+
auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
|
|
9425
|
+
cT.store(tC);
|
|
9426
|
+
#else
|
|
9427
|
+
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
|
|
8275
9428
|
|
|
8276
|
-
#pragma unroll(8)
|
|
8277
9429
|
for (short i = 0; i < 8; i++) {
|
|
8278
|
-
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*
|
|
9430
|
+
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
|
|
8279
9431
|
}
|
|
9432
|
+
#endif
|
|
8280
9433
|
|
|
8281
9434
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
8282
9435
|
|
|
8283
|
-
for (short j = sgitg; j <
|
|
8284
|
-
const int id = ids_i32[im*args.ne21 + r1
|
|
9436
|
+
for (short j = sgitg; j < nr1; j += 4) {
|
|
9437
|
+
const int id = ids_i32[im*args.ne21 + r1 + j];
|
|
8285
9438
|
|
|
8286
9439
|
const short ide = id % args.ne20;
|
|
8287
9440
|
const short idt = id / args.ne20;
|
|
8288
9441
|
|
|
8289
|
-
device float * D = (device float *) dst +
|
|
9442
|
+
device float * D = (device float *) dst + r0 + ide*args.ne0 + idt*args.ne1*args.ne0;
|
|
8290
9443
|
device float4 * D4 = (device float4 *) D;
|
|
8291
9444
|
|
|
8292
|
-
threadgroup float * C = (threadgroup float *) shmem +
|
|
9445
|
+
threadgroup float * C = (threadgroup float *) shmem + j*NR0;
|
|
8293
9446
|
threadgroup float4 * C4 = (threadgroup float4 *) C;
|
|
8294
9447
|
|
|
8295
9448
|
int i = tiisg;
|
|
8296
|
-
for (; i <
|
|
9449
|
+
for (; i < nr0/4; i += 32) {
|
|
8297
9450
|
*(D4 + i) = *(C4 + i);
|
|
8298
9451
|
}
|
|
8299
9452
|
|
|
8300
|
-
i = (4*(
|
|
8301
|
-
for (; i <
|
|
9453
|
+
i = (4*(nr0/4)) + tiisg;
|
|
9454
|
+
for (; i < nr0; i += 32) {
|
|
8302
9455
|
*(D + i) = *(C + i);
|
|
8303
9456
|
}
|
|
8304
9457
|
}
|
|
@@ -8310,12 +9463,13 @@ kernel void kernel_mul_mm_id(
|
|
|
8310
9463
|
// get rows
|
|
8311
9464
|
//
|
|
8312
9465
|
|
|
8313
|
-
typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
|
|
9466
|
+
typedef decltype(kernel_get_rows_f<float, float>) get_rows_f_t;
|
|
8314
9467
|
|
|
8315
|
-
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
|
|
8316
|
-
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
|
|
9468
|
+
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float, float>;
|
|
9469
|
+
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half, float>;
|
|
9470
|
+
template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f<int32_t, int32_t>;
|
|
8317
9471
|
#if defined(GGML_METAL_HAS_BF16)
|
|
8318
|
-
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
|
|
9472
|
+
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat, float>;
|
|
8319
9473
|
#endif
|
|
8320
9474
|
|
|
8321
9475
|
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
|
|
@@ -8405,9 +9559,6 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
|
|
|
8405
9559
|
|
|
8406
9560
|
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
|
8407
9561
|
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
|
8408
|
-
#if defined(GGML_METAL_HAS_BF16)
|
|
8409
|
-
template [[host_name("kernel_mul_mm_bf16_f16")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, half, half2x4>;
|
|
8410
|
-
#endif
|
|
8411
9562
|
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
|
8412
9563
|
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
|
8413
9564
|
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
|
@@ -8463,9 +9614,6 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m
|
|
|
8463
9614
|
|
|
8464
9615
|
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
|
8465
9616
|
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
|
8466
|
-
#if defined(GGML_METAL_HAS_BF16)
|
|
8467
|
-
template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, half, half2x4>;
|
|
8468
|
-
#endif
|
|
8469
9617
|
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
|
8470
9618
|
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
|
8471
9619
|
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
|
@@ -8720,3 +9868,123 @@ kernel void kernel_pool_2d_avg_f32(
|
|
|
8720
9868
|
|
|
8721
9869
|
o_ptr[cur_oh * args.OW + cur_ow] = res;
|
|
8722
9870
|
}
|
|
9871
|
+
|
|
9872
|
+
// AdamW optimizer step: one thread updates one parameter element in place.
//
// pars, as read below:
//   [0] alpha  - step size (scales both the decay factor and the update)
//   [1] beta1  - decay rate of the first-moment average (g_m)
//   [2] beta2  - decay rate of the second-moment average (g_v)
//   [3] eps    - added to the denominator after the square root
//   [4] wd     - weight-decay coefficient, applied as x * (1 - alpha*wd)
//   [5] beta1h - multiplier applied to the new first moment
//   [6] beta2h - multiplier applied to the new second moment before sqrt
// NOTE(review): beta1h/beta2h look like host-precomputed bias-correction
// factors - confirm against the host-side encoder.
//
// Side effects: g_m[gid], g_v[gid] and x[gid] are overwritten.
kernel void kernel_opt_step_adamw_f32(
        constant ggml_metal_kargs_opt_step_adamw & args,
        device       float * x,
        device const float * g,
        device       float * g_m,
        device       float * g_v,
        device const float * pars,
        uint gid[[thread_position_in_grid]]) {
    // Only threads that map onto a real parameter do any work.
    if (gid < args.np) {
        const float lr    = pars[0];
        const float b1    = pars[1];
        const float b2    = pars[2];
        const float eps   = pars[3];
        const float decay = pars[4];
        const float b1h   = pars[5];
        const float b2h   = pars[6];

        const float grad = g[gid];

        // Exponential moving averages of the gradient and of its square.
        const float m = g_m[gid] * b1 + grad * (1.0f - b1);
        const float v = g_v[gid] * b2 + grad * grad * (1.0f - b2);

        g_m[gid] = m;
        g_v[gid] = v;

        const float m_hat = m * b1h;
        const float denom = sqrt(v * b2h) + eps;

        // Decoupled weight decay followed by the Adam step.
        x[gid] = x[gid] * (1.0f - lr * decay) - lr * m_hat / denom;
    }
}
|
|
9905
|
+
|
|
9906
|
+
// SGD step with weight decay: one thread updates one parameter element in place.
//
//   x <- x * (1 - pars[0]*pars[1]) - pars[0] * g
//
// pars[0] scales both the decay and the gradient step (the step size);
// pars[1] is the weight-decay coefficient.
kernel void kernel_opt_step_sgd_f32(
        constant ggml_metal_kargs_opt_step_sgd & args,
        device       float * x,
        device const float * g,
        device const float * pars,
        uint gid[[thread_position_in_grid]]) {
    // Only threads that map onto a real parameter do any work.
    if (gid < args.np) {
        const float step  = pars[0];
        const float decay = pars[1];

        x[gid] = x[gid] * (1.0f - step * decay) - step * g[gid];
    }
}
|
|
9919
|
+
|
|
9920
|
+
// Fill kernel: each thread stores args.val into the dst element at its grid
// position.
// NOTE(review): there is no bounds check, so the dispatch presumably covers
// exactly the number of dst elements - confirm at the call site.
template<typename T>
kernel void kernel_memset(
        constant ggml_metal_kargs_fill & args,
        device T * dst,
        uint tpig[[thread_position_in_grid]]) {
    device T * out = dst + tpig;
    *out = args.val;
}

typedef decltype(kernel_memset<int64_t>) kernel_memset_t;

template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset<int64_t>;
|
|
9931
|
+
|
|
9932
|
+
// Number of simdgroups per threadgroup, provided as a function constant when
// the pipeline is built.
constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];

// Counts positions where src0 and src1 hold equal values and atomically adds
// that count to *dst.
//
// One threadgroup handles one row (i1, i2, i3) taken from its grid position.
// Its threads stride over the row's ne00 elements. Per-thread counts are
// reduced within each simdgroup, written to shmem_i32[sgitg], and then reduced
// again by simdgroup 0.
//
// Requirements visible from the code:
//   - shmem_i32 must hold at least NSG ints;
//   - the final pass reads shmem_i32 from lanes of simdgroup 0 only, so NSG is
//     assumed to be <= the simdgroup width.
//
// The whole reduction stays in int. Routing the partial counts through float
// would lose exactness once a row's count exceeds 2^24.
template<typename T>
kernel void kernel_count_equal(
        constant ggml_metal_kargs_count_equal & args,
        device const char * src0,
        device const char * src1,
        device atomic_int * dst,
        threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const short NSG = FC_count_equal_nsg;

    const int i3 = tgpig.z;
    const int i2 = tgpig.y;
    const int i1 = tgpig.x;

    // The condition depends only on the threadgroup position, so the whole
    // threadgroup exits together and no barrier below is skipped by a subset.
    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
        return;
    }

    int sum = 0;

    device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
    device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;

    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
        const T v0 = *(device const T *)(base0 + i0*args.nb00);
        const T v1 = *(device const T *)(base1 + i0*args.nb10);
        sum += (v0 == v1);
    }

    sum = simd_sum(sum);

    if (tiisg == 0) {
        shmem_i32[sgitg] = sum;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (sgitg == 0) {
        // Keep the cross-simdgroup reduction in int (previously float) so
        // counts above 2^24 stay exact.
        int v = 0;
        if (tpitg.x < NSG) {
            v = shmem_i32[tpitg.x];
        }

        const int total = simd_sum(v);
        if (tpitg.x == 0) {
            atomic_fetch_add_explicit(dst, total, memory_order_relaxed);
        }
    }
}

typedef decltype(kernel_count_equal<int32_t>) kernel_count_equal_t;

template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal<int32_t>;
|