whispercpp 1.3.5 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +99 -2
- data/ext/extconf.rb +1 -0
- data/ext/ruby_whisper.c +20 -4
- data/ext/ruby_whisper.h +30 -2
- data/ext/ruby_whisper_context.c +216 -124
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_params.c +0 -1
- data/ext/ruby_whisper_segment.c +0 -1
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +4 -1
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +8 -0
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/server/server.cpp +18 -4
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
- data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
- data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
- data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
- data/ext/sources/examples/talk-llama/llama-context.h +27 -28
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
- data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
- data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
- data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
- data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
- data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
- data/ext/sources/examples/talk-llama/llama-model.h +72 -19
- data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
- data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
- data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -22
- data/ext/sources/examples/talk-llama/llama.h +63 -30
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
- data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
- data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
- data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
- data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
- data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/models.h +181 -46
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
- data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
- data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
- data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
- data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
- data/ext/sources/ggml/CMakeLists.txt +9 -3
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +6 -1
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +56 -9
- data/ext/sources/ggml/src/CMakeLists.txt +3 -0
- data/ext/sources/ggml/src/ggml-alloc.c +4 -9
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
- data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
- data/ext/sources/ggml/src/ggml-impl.h +62 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
- data/ext/sources/ggml/src/ggml.c +167 -33
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/src/whisper.cpp +6 -28
- data/sig/whisper.rbs +43 -2
- data/test/test_context_params.rb +82 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_whisper.rb +20 -0
- data/whispercpp.gemspec +1 -1
- metadata +240 -28
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
|
@@ -77,6 +77,14 @@ static inline float dot(float x, float y) {
|
|
|
77
77
|
return x*y;
|
|
78
78
|
}
|
|
79
79
|
|
|
80
|
+
static inline float sum(float x) {
|
|
81
|
+
return x;
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
static inline float sum(float4 x) {
|
|
85
|
+
return x[0] + x[1] + x[2] + x[3];
|
|
86
|
+
}
|
|
87
|
+
|
|
80
88
|
// NOTE: this is not dequantizing - we are simply fitting the template
|
|
81
89
|
template <typename type4x4>
|
|
82
90
|
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
|
@@ -895,752 +903,432 @@ enum ggml_sort_order {
|
|
|
895
903
|
GGML_SORT_ORDER_DESC,
|
|
896
904
|
};
|
|
897
905
|
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
kernel void kernel_add_fuse_impl(
|
|
903
|
-
constant ggml_metal_kargs_bin & args,
|
|
904
|
-
device const char * src0,
|
|
905
|
-
device const char * src1,
|
|
906
|
-
device char * dst,
|
|
907
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
908
|
-
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
909
|
-
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
910
|
-
const int i03 = tgpig.z;
|
|
911
|
-
const int i02 = tgpig.y;
|
|
912
|
-
const int i01 = tgpig.x;
|
|
906
|
+
constant float GELU_COEF_A = 0.044715f;
|
|
907
|
+
constant float GELU_QUICK_COEF = -1.702f;
|
|
908
|
+
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
|
909
|
+
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
|
|
913
910
|
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
911
|
+
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
|
|
912
|
+
// ref: https://www.johndcook.com/blog/python_erf/
|
|
913
|
+
constant float p_erf = 0.3275911f;
|
|
914
|
+
constant float a1_erf = 0.254829592f;
|
|
915
|
+
constant float a2_erf = -0.284496736f;
|
|
916
|
+
constant float a3_erf = 1.421413741f;
|
|
917
|
+
constant float a4_erf = -1.453152027f;
|
|
918
|
+
constant float a5_erf = 1.061405429f;
|
|
917
919
|
|
|
918
|
-
|
|
919
|
-
|
|
920
|
+
template<typename T>
|
|
921
|
+
inline T erf_approx(T x) {
|
|
922
|
+
T sign_x = sign(x);
|
|
923
|
+
x = fabs(x);
|
|
924
|
+
T t = 1.0f / (1.0f + p_erf * x);
|
|
925
|
+
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
|
|
926
|
+
return sign_x * y;
|
|
927
|
+
}
|
|
920
928
|
|
|
921
|
-
|
|
922
|
-
for (short j = 0; j < F; ++j) {
|
|
923
|
-
src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
|
|
924
|
-
}
|
|
929
|
+
template<typename T> T elu_approx(T x);
|
|
925
930
|
|
|
926
|
-
|
|
927
|
-
|
|
931
|
+
template<> inline float elu_approx<float>(float x) {
|
|
932
|
+
return (x > 0.f) ? x : (exp(x) - 1);
|
|
933
|
+
}
|
|
928
934
|
|
|
929
|
-
|
|
935
|
+
template<> inline float4 elu_approx<float4>(float4 x) {
|
|
936
|
+
float4 res;
|
|
930
937
|
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
938
|
+
res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
|
|
939
|
+
res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
|
|
940
|
+
res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
|
|
941
|
+
res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
|
|
935
942
|
|
|
936
|
-
|
|
937
|
-
}
|
|
943
|
+
return res;
|
|
938
944
|
}
|
|
939
945
|
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
|
|
943
|
-
template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
|
|
944
|
-
template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
|
|
945
|
-
template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
|
|
946
|
-
template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
|
|
947
|
-
template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
|
|
948
|
-
template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
|
|
949
|
-
template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
|
|
946
|
+
constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
|
|
947
|
+
constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
|
|
950
948
|
|
|
951
|
-
|
|
952
|
-
|
|
949
|
+
template <typename T0, typename T, typename TC>
|
|
950
|
+
kernel void kernel_unary_impl(
|
|
951
|
+
constant ggml_metal_kargs_unary & args,
|
|
953
952
|
device const char * src0,
|
|
954
|
-
device const char * src1,
|
|
955
953
|
device char * dst,
|
|
956
954
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
957
955
|
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
958
956
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
const int i01 = tgpig.x;
|
|
962
|
-
|
|
963
|
-
const int i13 = i03%args.ne13;
|
|
964
|
-
const int i12 = i02%args.ne12;
|
|
965
|
-
const int i11 = i01%args.ne11;
|
|
957
|
+
#define FC_OP FC_unary_op
|
|
958
|
+
#define FC_CNT FC_unary_cnt
|
|
966
959
|
|
|
967
|
-
device const
|
|
968
|
-
device
|
|
969
|
-
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
|
960
|
+
device const T0 * src0_ptr;
|
|
961
|
+
device T * dst_ptr;
|
|
970
962
|
|
|
971
|
-
|
|
972
|
-
const int i10 = i0%args.ne10;
|
|
973
|
-
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10));
|
|
974
|
-
}
|
|
975
|
-
}
|
|
963
|
+
int i0;
|
|
976
964
|
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
device const char * src0,
|
|
980
|
-
device const char * src1,
|
|
981
|
-
device char * dst,
|
|
982
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
983
|
-
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
984
|
-
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
985
|
-
const int i03 = tgpig.z;
|
|
986
|
-
const int i02 = tgpig.y;
|
|
987
|
-
const int i01 = tgpig.x;
|
|
965
|
+
if (FC_CNT) {
|
|
966
|
+
i0 = tgpig.x;
|
|
988
967
|
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
|
|
968
|
+
src0_ptr = (device const T0 *) (src0);
|
|
969
|
+
dst_ptr = (device T *) (dst);
|
|
970
|
+
} else {
|
|
971
|
+
const int i03 = tgpig.z;
|
|
972
|
+
const int i02 = tgpig.y;
|
|
973
|
+
const int k0 = tgpig.x/args.ne01;
|
|
974
|
+
const int i01 = tgpig.x - k0*args.ne01;
|
|
992
975
|
|
|
993
|
-
|
|
994
|
-
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
|
|
995
|
-
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
|
976
|
+
i0 = k0*ntg.x + tpitg.x;
|
|
996
977
|
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
1000
|
-
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
|
|
1001
|
-
}
|
|
1002
|
-
} else {
|
|
1003
|
-
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
1004
|
-
const int i10 = i0%args.ne10;
|
|
1005
|
-
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
|
|
1006
|
-
}
|
|
978
|
+
src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
|
979
|
+
dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 );
|
|
1007
980
|
}
|
|
1008
|
-
}
|
|
1009
981
|
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
device const char * src0,
|
|
1013
|
-
device const char * src1,
|
|
1014
|
-
device char * dst,
|
|
1015
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1016
|
-
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
1017
|
-
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
1018
|
-
const int i03 = tgpig.z;
|
|
1019
|
-
const int i02 = tgpig.y;
|
|
1020
|
-
const int i01 = tgpig.x;
|
|
982
|
+
{
|
|
983
|
+
//threadgroup_barrier(mem_flags::mem_none);
|
|
1021
984
|
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
985
|
+
if (!FC_CNT) {
|
|
986
|
+
if (i0 >= args.ne0) {
|
|
987
|
+
return;
|
|
988
|
+
}
|
|
989
|
+
}
|
|
1025
990
|
|
|
1026
|
-
|
|
1027
|
-
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
|
|
1028
|
-
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
|
991
|
+
const TC x = (TC) src0_ptr[i0];
|
|
1029
992
|
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
1033
|
-
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
|
|
993
|
+
if (FC_OP == OP_UNARY_NUM_SCALE) {
|
|
994
|
+
dst_ptr[i0] = (T) (args.scale * x + args.bias);
|
|
1034
995
|
}
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
|
|
996
|
+
|
|
997
|
+
if (FC_OP == OP_UNARY_NUM_FILL) {
|
|
998
|
+
dst_ptr[i0] = (T) args.val;
|
|
1039
999
|
}
|
|
1040
|
-
}
|
|
1041
|
-
}
|
|
1042
1000
|
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
device const char * src1,
|
|
1047
|
-
device const char * src2,
|
|
1048
|
-
device char * dst,
|
|
1049
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1050
|
-
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
1051
|
-
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
1052
|
-
const int i1 = tgpig.x;
|
|
1053
|
-
const int i2 = tgpig.y;
|
|
1001
|
+
if (FC_OP == OP_UNARY_NUM_CLAMP) {
|
|
1002
|
+
dst_ptr[i0] = (T) clamp(x, args.min, args.max);
|
|
1003
|
+
}
|
|
1054
1004
|
|
|
1055
|
-
|
|
1005
|
+
if (FC_OP == OP_UNARY_NUM_SQR) {
|
|
1006
|
+
dst_ptr[i0] = (T) (x * x);
|
|
1007
|
+
}
|
|
1056
1008
|
|
|
1057
|
-
|
|
1058
|
-
|
|
1009
|
+
if (FC_OP == OP_UNARY_NUM_SQRT) {
|
|
1010
|
+
dst_ptr[i0] = (T) sqrt(x);
|
|
1011
|
+
}
|
|
1059
1012
|
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1013
|
+
if (FC_OP == OP_UNARY_NUM_SIN) {
|
|
1014
|
+
dst_ptr[i0] = (T) sin(x);
|
|
1015
|
+
}
|
|
1063
1016
|
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
}
|
|
1017
|
+
if (FC_OP == OP_UNARY_NUM_COS) {
|
|
1018
|
+
dst_ptr[i0] = (T) cos(x);
|
|
1019
|
+
}
|
|
1068
1020
|
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
device const char * src0,
|
|
1073
|
-
device char * dst,
|
|
1074
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1075
|
-
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
1076
|
-
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
1077
|
-
const int i3 = tgpig.z;
|
|
1078
|
-
const int i2 = tgpig.y;
|
|
1079
|
-
const int i1 = tgpig.x;
|
|
1021
|
+
if (FC_OP == OP_UNARY_NUM_LOG) {
|
|
1022
|
+
dst_ptr[i0] = (T) log(x);
|
|
1023
|
+
}
|
|
1080
1024
|
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
|
-
|
|
1025
|
+
if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
|
|
1026
|
+
dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
|
|
1027
|
+
}
|
|
1084
1028
|
|
|
1085
|
-
|
|
1086
|
-
|
|
1029
|
+
if (FC_OP == OP_UNARY_NUM_TANH) {
|
|
1030
|
+
dst_ptr[i0] = (T) precise::tanh(x);
|
|
1031
|
+
}
|
|
1087
1032
|
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
}
|
|
1092
|
-
}
|
|
1033
|
+
if (FC_OP == OP_UNARY_NUM_RELU) {
|
|
1034
|
+
dst_ptr[i0] = (T) fmax(0, x);
|
|
1035
|
+
}
|
|
1093
1036
|
|
|
1094
|
-
|
|
1037
|
+
if (FC_OP == OP_UNARY_NUM_SIGMOID) {
|
|
1038
|
+
dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
|
|
1039
|
+
}
|
|
1095
1040
|
|
|
1096
|
-
|
|
1097
|
-
|
|
1098
|
-
|
|
1099
|
-
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
|
|
1041
|
+
if (FC_OP == OP_UNARY_NUM_GELU) {
|
|
1042
|
+
dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
|
|
1043
|
+
}
|
|
1100
1044
|
|
|
1101
|
-
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
kernel void kernel_add_row_c4_fuse_impl(
|
|
1105
|
-
constant ggml_metal_kargs_bin & args,
|
|
1106
|
-
device const char * src0,
|
|
1107
|
-
device const char * src1,
|
|
1108
|
-
device char * dst,
|
|
1109
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1110
|
-
const uint nb = args.ne00/4;
|
|
1111
|
-
const uint i = tpig % nb;
|
|
1045
|
+
if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
|
|
1046
|
+
dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
|
|
1047
|
+
}
|
|
1112
1048
|
|
|
1113
|
-
|
|
1114
|
-
|
|
1049
|
+
if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
|
|
1050
|
+
dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
|
|
1051
|
+
}
|
|
1115
1052
|
|
|
1116
|
-
|
|
1053
|
+
if (FC_OP == OP_UNARY_NUM_SILU) {
|
|
1054
|
+
dst_ptr[i0] = (T) (x / (1 + exp(-x)));
|
|
1055
|
+
}
|
|
1117
1056
|
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
}
|
|
1057
|
+
if (FC_OP == OP_UNARY_NUM_ELU) {
|
|
1058
|
+
dst_ptr[i0] = (T) elu_approx(x);
|
|
1059
|
+
}
|
|
1122
1060
|
|
|
1123
|
-
|
|
1124
|
-
|
|
1061
|
+
if (FC_OP == OP_UNARY_NUM_NEG) {
|
|
1062
|
+
dst_ptr[i0] = (T) -x;
|
|
1063
|
+
}
|
|
1125
1064
|
|
|
1126
|
-
|
|
1065
|
+
if (FC_OP == OP_UNARY_NUM_ABS) {
|
|
1066
|
+
dst_ptr[i0] = (T) fabs(x);
|
|
1067
|
+
}
|
|
1127
1068
|
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
|
|
1132
|
-
template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
|
|
1133
|
-
template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
|
|
1134
|
-
template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
|
|
1135
|
-
template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;
|
|
1069
|
+
if (FC_OP == OP_UNARY_NUM_SGN) {
|
|
1070
|
+
dst_ptr[i0] = T(x > 0) - T(x < 0);
|
|
1071
|
+
}
|
|
1136
1072
|
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
|
|
1140
|
-
device const char * src0,
|
|
1141
|
-
device const char * src1,
|
|
1142
|
-
device char * dst,
|
|
1143
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1073
|
+
if (FC_OP == OP_UNARY_NUM_STEP) {
|
|
1074
|
+
dst_ptr[i0] = T(x > 0);
|
|
1075
|
+
}
|
|
1144
1076
|
|
|
1145
|
-
|
|
1146
|
-
|
|
1077
|
+
if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
|
|
1078
|
+
dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
|
|
1079
|
+
}
|
|
1147
1080
|
|
|
1148
|
-
|
|
1149
|
-
|
|
1081
|
+
if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
|
|
1082
|
+
dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
|
|
1083
|
+
}
|
|
1150
1084
|
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
}
|
|
1085
|
+
if (FC_OP == OP_UNARY_NUM_EXP) {
|
|
1086
|
+
dst_ptr[i0] = (T) exp(x);
|
|
1087
|
+
}
|
|
1155
1088
|
|
|
1156
|
-
|
|
1089
|
+
if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
|
|
1090
|
+
dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
|
|
1091
|
+
}
|
|
1157
1092
|
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
|
|
1093
|
+
if (FC_OP == OP_UNARY_NUM_EXPM1) {
|
|
1094
|
+
// TODO: precise implementation
|
|
1095
|
+
dst_ptr[i0] = (T) (exp(x) - 1);
|
|
1096
|
+
}
|
|
1161
1097
|
}
|
|
1162
1098
|
|
|
1163
|
-
|
|
1099
|
+
#undef FC_OP
|
|
1100
|
+
#undef FC_CNT
|
|
1164
1101
|
}
|
|
1165
1102
|
|
|
1166
|
-
typedef decltype(
|
|
1103
|
+
typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;
|
|
1167
1104
|
|
|
1168
|
-
template [[host_name("
|
|
1105
|
+
template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float, float>;
|
|
1106
|
+
template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>;
|
|
1107
|
+
template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl<half, half, float>;
|
|
1108
|
+
template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl<half4, half4, float4>;
|
|
1169
1109
|
|
|
1170
|
-
|
|
1171
|
-
|
|
1172
|
-
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
device char * dst,
|
|
1176
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1177
|
-
|
|
1178
|
-
const uint nb = args.ne00/4;
|
|
1179
|
-
const uint i = tpig % nb;
|
|
1180
|
-
|
|
1181
|
-
device const float4 * src0_row = (device const float4 *) (src0);
|
|
1182
|
-
device float4 * dst_row = (device float4 *) (dst);
|
|
1110
|
+
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
|
|
1111
|
+
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
|
|
1112
|
+
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
|
|
1113
|
+
constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]];
|
|
1114
|
+
constant bool FC_bin_cb [[function_constant(FC_BIN + 3)]];
|
|
1183
1115
|
|
|
1184
|
-
|
|
1185
|
-
|
|
1186
|
-
src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
|
|
1187
|
-
}
|
|
1188
|
-
|
|
1189
|
-
float4 res = src0_row[tpig];
|
|
1190
|
-
|
|
1191
|
-
#pragma unroll(F)
|
|
1192
|
-
for (short j = 0; j < F; ++j) {
|
|
1193
|
-
res *= src1_row[j][i];
|
|
1194
|
-
}
|
|
1195
|
-
|
|
1196
|
-
dst_row[tpig] = res;
|
|
1197
|
-
}
|
|
1198
|
-
|
|
1199
|
-
typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
|
|
1200
|
-
|
|
1201
|
-
template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
|
|
1202
|
-
|
|
1203
|
-
template <short F>
|
|
1204
|
-
kernel void kernel_div_row_c4_fuse_impl(
|
|
1116
|
+
template <typename T0, typename T1, typename T>
|
|
1117
|
+
kernel void kernel_bin_fuse_impl(
|
|
1205
1118
|
constant ggml_metal_kargs_bin & args,
|
|
1206
1119
|
device const char * src0,
|
|
1207
1120
|
device const char * src1,
|
|
1208
1121
|
device char * dst,
|
|
1209
|
-
|
|
1210
|
-
|
|
1211
|
-
|
|
1212
|
-
|
|
1213
|
-
|
|
1214
|
-
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
device const float4 * src1_row[F];
|
|
1218
|
-
for (short j = 0; j < F; ++j) {
|
|
1219
|
-
src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
|
|
1220
|
-
}
|
|
1221
|
-
|
|
1222
|
-
float4 res = src0_row[tpig];
|
|
1223
|
-
|
|
1224
|
-
#pragma unroll(F)
|
|
1225
|
-
for (short j = 0; j < F; ++j) {
|
|
1226
|
-
res /= src1_row[j][i];
|
|
1227
|
-
}
|
|
1228
|
-
|
|
1229
|
-
dst_row[tpig] = res;
|
|
1230
|
-
}
|
|
1231
|
-
|
|
1232
|
-
typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
|
|
1233
|
-
|
|
1234
|
-
template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
|
|
1235
|
-
|
|
1236
|
-
kernel void kernel_scale_f32(
|
|
1237
|
-
constant ggml_metal_kargs_scale & args,
|
|
1238
|
-
device const float * src0,
|
|
1239
|
-
device float * dst,
|
|
1240
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1241
|
-
dst[tpig] = src0[tpig] * args.scale + args.bias;
|
|
1242
|
-
}
|
|
1243
|
-
|
|
1244
|
-
kernel void kernel_scale_f32_4(
|
|
1245
|
-
constant ggml_metal_kargs_scale & args,
|
|
1246
|
-
device const float4 * src0,
|
|
1247
|
-
device float4 * dst,
|
|
1248
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1249
|
-
dst[tpig] = src0[tpig] * args.scale + args.bias;
|
|
1250
|
-
}
|
|
1251
|
-
|
|
1252
|
-
kernel void kernel_fill_f32(
|
|
1253
|
-
constant ggml_metal_kargs_fill & args,
|
|
1254
|
-
device const float * src0,
|
|
1255
|
-
device float * dst,
|
|
1256
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1257
|
-
dst[tpig] = args.val;
|
|
1258
|
-
}
|
|
1259
|
-
|
|
1260
|
-
kernel void kernel_fill_f32_4(
|
|
1261
|
-
constant ggml_metal_kargs_fill & args,
|
|
1262
|
-
device const float4 * src0,
|
|
1263
|
-
device float4 * dst,
|
|
1264
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1265
|
-
dst[tpig] = args.val;
|
|
1266
|
-
}
|
|
1267
|
-
|
|
1268
|
-
kernel void kernel_clamp_f32(
|
|
1269
|
-
constant ggml_metal_kargs_clamp & args,
|
|
1270
|
-
device const float * src0,
|
|
1271
|
-
device float * dst,
|
|
1272
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1273
|
-
dst[tpig] = clamp(src0[tpig], args.min, args.max);
|
|
1274
|
-
}
|
|
1275
|
-
|
|
1276
|
-
kernel void kernel_clamp_f32_4(
|
|
1277
|
-
constant ggml_metal_kargs_clamp & args,
|
|
1278
|
-
device const float4 * src0,
|
|
1279
|
-
device float4 * dst,
|
|
1280
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1281
|
-
dst[tpig] = clamp(src0[tpig], args.min, args.max);
|
|
1282
|
-
}
|
|
1283
|
-
|
|
1284
|
-
kernel void kernel_relu_f32(
|
|
1285
|
-
device const float * src0,
|
|
1286
|
-
device float * dst,
|
|
1287
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1288
|
-
dst[tpig] = max(0.0f, src0[tpig]);
|
|
1289
|
-
}
|
|
1290
|
-
|
|
1291
|
-
kernel void kernel_relu_f32_4(
|
|
1292
|
-
device const float4 * src0,
|
|
1293
|
-
device float4 * dst,
|
|
1294
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1295
|
-
dst[tpig] = max(0.0f, src0[tpig]);
|
|
1296
|
-
}
|
|
1297
|
-
|
|
1298
|
-
kernel void kernel_sigmoid_f32(
|
|
1299
|
-
device const float * src0,
|
|
1300
|
-
device float * dst,
|
|
1301
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1302
|
-
dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
|
|
1303
|
-
}
|
|
1304
|
-
|
|
1305
|
-
kernel void kernel_sigmoid_f32_4(
|
|
1306
|
-
device const float4 * src0,
|
|
1307
|
-
device float4 * dst,
|
|
1308
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1309
|
-
dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
|
|
1310
|
-
}
|
|
1311
|
-
|
|
1312
|
-
kernel void kernel_tanh_f32(
|
|
1313
|
-
device const float * src0,
|
|
1314
|
-
device float * dst,
|
|
1315
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1316
|
-
dst[tpig] = precise::tanh(src0[tpig]);
|
|
1317
|
-
}
|
|
1318
|
-
|
|
1319
|
-
kernel void kernel_tanh_f32_4(
|
|
1320
|
-
device const float4 * src0,
|
|
1321
|
-
device float4 * dst,
|
|
1322
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1323
|
-
dst[tpig] = precise::tanh(src0[tpig]);
|
|
1324
|
-
}
|
|
1325
|
-
|
|
1326
|
-
constant float GELU_COEF_A = 0.044715f;
|
|
1327
|
-
constant float GELU_QUICK_COEF = -1.702f;
|
|
1328
|
-
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
|
1329
|
-
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
|
|
1330
|
-
|
|
1331
|
-
kernel void kernel_gelu_f32(
|
|
1332
|
-
device const float * src0,
|
|
1333
|
-
device float * dst,
|
|
1334
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1335
|
-
device const float & x = src0[tpig];
|
|
1336
|
-
|
|
1337
|
-
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
|
1338
|
-
}
|
|
1339
|
-
|
|
1340
|
-
kernel void kernel_gelu_f32_4(
|
|
1341
|
-
device const float4 * src0,
|
|
1342
|
-
device float4 * dst,
|
|
1343
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1344
|
-
device const float4 & x = src0[tpig];
|
|
1345
|
-
|
|
1346
|
-
// BEWARE !!!
|
|
1347
|
-
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
|
|
1348
|
-
// This was observed with Falcon 7B and 40B models
|
|
1349
|
-
//
|
|
1350
|
-
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
|
1351
|
-
}
|
|
1352
|
-
|
|
1353
|
-
kernel void kernel_gelu_quick_f32(
|
|
1354
|
-
device const float * src0,
|
|
1355
|
-
device float * dst,
|
|
1356
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1357
|
-
device const float & x = src0[tpig];
|
|
1358
|
-
|
|
1359
|
-
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
|
|
1360
|
-
}
|
|
1122
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1123
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
1124
|
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
1125
|
+
#define FC_OP FC_bin_op
|
|
1126
|
+
#define FC_F FC_bin_f
|
|
1127
|
+
#define FC_RB FC_bin_rb
|
|
1128
|
+
#define FC_CB FC_bin_cb
|
|
1361
1129
|
|
|
1362
|
-
|
|
1363
|
-
|
|
1364
|
-
|
|
1365
|
-
|
|
1366
|
-
device const float4 & x = src0[tpig];
|
|
1130
|
+
if (FC_RB) {
|
|
1131
|
+
// row broadcast
|
|
1132
|
+
const uint i0 = tgpig.y*args.ne00 + tgpig.x;
|
|
1133
|
+
const uint i1 = FC_CB ? tgpig.x%args.ne10 : tgpig.x;
|
|
1367
1134
|
|
|
1368
|
-
|
|
1369
|
-
|
|
1135
|
+
device const T0 * src0_row = (device const T0 *) (src0);
|
|
1136
|
+
device T * dst_row = (device T *) (dst);
|
|
1370
1137
|
|
|
1371
|
-
|
|
1372
|
-
|
|
1373
|
-
constant float p_erf = 0.3275911f;
|
|
1374
|
-
constant float a1_erf = 0.254829592f;
|
|
1375
|
-
constant float a2_erf = -0.284496736f;
|
|
1376
|
-
constant float a3_erf = 1.421413741f;
|
|
1377
|
-
constant float a4_erf = -1.453152027f;
|
|
1378
|
-
constant float a5_erf = 1.061405429f;
|
|
1138
|
+
if (FC_F == 1) {
|
|
1139
|
+
device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);
|
|
1379
1140
|
|
|
1380
|
-
|
|
1381
|
-
|
|
1382
|
-
|
|
1383
|
-
x = fabs(x);
|
|
1384
|
-
T t = 1.0f / (1.0f + p_erf * x);
|
|
1385
|
-
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
|
|
1386
|
-
return sign_x * y;
|
|
1387
|
-
}
|
|
1141
|
+
if (FC_OP == 0) {
|
|
1142
|
+
dst_row[i0] = src0_row[i0] + src1_row[i1];
|
|
1143
|
+
}
|
|
1388
1144
|
|
|
1389
|
-
|
|
1390
|
-
|
|
1391
|
-
|
|
1392
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1393
|
-
device const float & x = src0[tpig];
|
|
1145
|
+
if (FC_OP == 1) {
|
|
1146
|
+
dst_row[i0] = src0_row[i0] - src1_row[i1];
|
|
1147
|
+
}
|
|
1394
1148
|
|
|
1395
|
-
|
|
1396
|
-
|
|
1149
|
+
if (FC_OP == 2) {
|
|
1150
|
+
dst_row[i0] = src0_row[i0] * src1_row[i1];
|
|
1151
|
+
}
|
|
1397
1152
|
|
|
1398
|
-
|
|
1399
|
-
|
|
1400
|
-
|
|
1401
|
-
|
|
1402
|
-
|
|
1153
|
+
if (FC_OP == 3) {
|
|
1154
|
+
dst_row[i0] = src0_row[i0] / src1_row[i1];
|
|
1155
|
+
}
|
|
1156
|
+
} else {
|
|
1157
|
+
T0 res = src0_row[i0];
|
|
1403
1158
|
|
|
1404
|
-
|
|
1405
|
-
|
|
1159
|
+
if (FC_OP == 0) {
|
|
1160
|
+
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
|
1161
|
+
res += ((device const T1 *) (src1 + args.o1[j]))[i1];
|
|
1162
|
+
}
|
|
1163
|
+
}
|
|
1406
1164
|
|
|
1407
|
-
|
|
1408
|
-
|
|
1409
|
-
|
|
1410
|
-
|
|
1411
|
-
|
|
1412
|
-
dst[tpig] = x / (1.0f + exp(-x));
|
|
1413
|
-
}
|
|
1165
|
+
if (FC_OP == 1) {
|
|
1166
|
+
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
|
1167
|
+
res -= ((device const T1 *) (src1 + args.o1[j]))[i1];
|
|
1168
|
+
}
|
|
1169
|
+
}
|
|
1414
1170
|
|
|
1415
|
-
|
|
1416
|
-
|
|
1417
|
-
|
|
1418
|
-
|
|
1419
|
-
|
|
1420
|
-
dst[tpig] = x / (1.0f + exp(-x));
|
|
1421
|
-
}
|
|
1171
|
+
if (FC_OP == 2) {
|
|
1172
|
+
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
|
1173
|
+
res *= ((device const T1 *) (src1 + args.o1[j]))[i1];
|
|
1174
|
+
}
|
|
1175
|
+
}
|
|
1422
1176
|
|
|
1423
|
-
|
|
1424
|
-
|
|
1425
|
-
|
|
1426
|
-
|
|
1427
|
-
|
|
1428
|
-
dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
|
|
1429
|
-
}
|
|
1177
|
+
if (FC_OP == 3) {
|
|
1178
|
+
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
|
1179
|
+
res /= ((device const T1 *) (src1 + args.o1[j]))[i1];
|
|
1180
|
+
}
|
|
1181
|
+
}
|
|
1430
1182
|
|
|
1431
|
-
|
|
1432
|
-
|
|
1433
|
-
|
|
1434
|
-
|
|
1435
|
-
|
|
1436
|
-
|
|
1437
|
-
dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
|
|
1438
|
-
dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
|
|
1439
|
-
dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
|
|
1440
|
-
}
|
|
1183
|
+
dst_row[i0] = res;
|
|
1184
|
+
}
|
|
1185
|
+
} else {
|
|
1186
|
+
const int i03 = tgpig.z;
|
|
1187
|
+
const int i02 = tgpig.y;
|
|
1188
|
+
const int i01 = tgpig.x;
|
|
1441
1189
|
|
|
1442
|
-
|
|
1443
|
-
|
|
1444
|
-
|
|
1445
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1446
|
-
dst[tpig] = src0[tpig] * src0[tpig];
|
|
1447
|
-
}
|
|
1190
|
+
if (i01 >= args.ne01) {
|
|
1191
|
+
return;
|
|
1192
|
+
}
|
|
1448
1193
|
|
|
1449
|
-
|
|
1450
|
-
|
|
1451
|
-
|
|
1452
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1453
|
-
dst[tpig] = src0[tpig] * src0[tpig];
|
|
1454
|
-
}
|
|
1194
|
+
const int i13 = i03%args.ne13;
|
|
1195
|
+
const int i12 = i02%args.ne12;
|
|
1196
|
+
const int i11 = i01%args.ne11;
|
|
1455
1197
|
|
|
1456
|
-
|
|
1457
|
-
device
|
|
1458
|
-
device float * dst,
|
|
1459
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1460
|
-
dst[tpig] = sqrt(src0[tpig]);
|
|
1461
|
-
}
|
|
1198
|
+
device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
|
|
1199
|
+
device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
|
|
1462
1200
|
|
|
1463
|
-
|
|
1464
|
-
|
|
1465
|
-
device float4 * dst,
|
|
1466
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1467
|
-
dst[tpig] = sqrt(src0[tpig]);
|
|
1468
|
-
}
|
|
1201
|
+
if (FC_F == 1) {
|
|
1202
|
+
device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
|
|
1469
1203
|
|
|
1470
|
-
|
|
1471
|
-
|
|
1472
|
-
device float * dst,
|
|
1473
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1474
|
-
dst[tpig] = sin(src0[tpig]);
|
|
1475
|
-
}
|
|
1204
|
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
1205
|
+
const int i10 = FC_CB ? i0%args.ne10 : i0;
|
|
1476
1206
|
|
|
1477
|
-
|
|
1478
|
-
|
|
1479
|
-
|
|
1480
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1481
|
-
dst[tpig] = sin(src0[tpig]);
|
|
1482
|
-
}
|
|
1207
|
+
if (FC_OP == 0) {
|
|
1208
|
+
dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
|
|
1209
|
+
}
|
|
1483
1210
|
|
|
1484
|
-
|
|
1485
|
-
|
|
1486
|
-
|
|
1487
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1488
|
-
dst[tpig] = cos(src0[tpig]);
|
|
1489
|
-
}
|
|
1211
|
+
if (FC_OP == 1) {
|
|
1212
|
+
dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
|
|
1213
|
+
}
|
|
1490
1214
|
|
|
1491
|
-
|
|
1492
|
-
|
|
1493
|
-
|
|
1494
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1495
|
-
dst[tpig] = cos(src0[tpig]);
|
|
1496
|
-
}
|
|
1215
|
+
if (FC_OP == 2) {
|
|
1216
|
+
dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
|
|
1217
|
+
}
|
|
1497
1218
|
|
|
1498
|
-
|
|
1499
|
-
|
|
1500
|
-
|
|
1501
|
-
|
|
1502
|
-
|
|
1503
|
-
|
|
1219
|
+
if (FC_OP == 3) {
|
|
1220
|
+
dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
|
|
1221
|
+
}
|
|
1222
|
+
}
|
|
1223
|
+
} else {
|
|
1224
|
+
device const T1 * src1_ptr[8];
|
|
1225
|
+
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
|
1226
|
+
src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
|
|
1227
|
+
}
|
|
1504
1228
|
|
|
1505
|
-
|
|
1506
|
-
|
|
1507
|
-
device float4 * dst,
|
|
1508
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1509
|
-
dst[tpig] = log(src0[tpig]);
|
|
1510
|
-
}
|
|
1229
|
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
1230
|
+
const int i10 = FC_CB ? i0%args.ne10 : i0;
|
|
1511
1231
|
|
|
1512
|
-
|
|
1513
|
-
device const float * src0,
|
|
1514
|
-
device float * dst,
|
|
1515
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1516
|
-
dst[tpig] = -src0[tpig];
|
|
1517
|
-
}
|
|
1232
|
+
T res = src0_ptr[i0];
|
|
1518
1233
|
|
|
1519
|
-
|
|
1520
|
-
|
|
1521
|
-
|
|
1522
|
-
|
|
1523
|
-
|
|
1524
|
-
}
|
|
1234
|
+
if (FC_OP == 0) {
|
|
1235
|
+
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
|
1236
|
+
res += src1_ptr[j][i10];
|
|
1237
|
+
}
|
|
1238
|
+
}
|
|
1525
1239
|
|
|
1526
|
-
|
|
1527
|
-
|
|
1528
|
-
|
|
1529
|
-
|
|
1530
|
-
|
|
1531
|
-
}
|
|
1240
|
+
if (FC_OP == 1) {
|
|
1241
|
+
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
|
1242
|
+
res -= src1_ptr[j][i10];
|
|
1243
|
+
}
|
|
1244
|
+
}
|
|
1532
1245
|
|
|
1533
|
-
|
|
1534
|
-
|
|
1535
|
-
|
|
1536
|
-
|
|
1537
|
-
|
|
1538
|
-
}
|
|
1246
|
+
if (FC_OP == 2) {
|
|
1247
|
+
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
|
1248
|
+
res *= src1_ptr[j][i10];
|
|
1249
|
+
}
|
|
1250
|
+
}
|
|
1539
1251
|
|
|
1540
|
-
|
|
1541
|
-
|
|
1542
|
-
|
|
1543
|
-
|
|
1544
|
-
|
|
1545
|
-
}
|
|
1252
|
+
if (FC_OP == 3) {
|
|
1253
|
+
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
|
1254
|
+
res /= src1_ptr[j][i10];
|
|
1255
|
+
}
|
|
1256
|
+
}
|
|
1546
1257
|
|
|
1547
|
-
|
|
1548
|
-
|
|
1549
|
-
|
|
1550
|
-
|
|
1551
|
-
dst[tpig] = sign(src0[tpig]);
|
|
1552
|
-
}
|
|
1258
|
+
dst_ptr[i0] = res;
|
|
1259
|
+
}
|
|
1260
|
+
}
|
|
1261
|
+
}
|
|
1553
1262
|
|
|
1554
|
-
|
|
1555
|
-
|
|
1556
|
-
|
|
1557
|
-
|
|
1558
|
-
dst[tpig] = step(0.0f, src0[tpig]);
|
|
1263
|
+
#undef FC_OP
|
|
1264
|
+
#undef FC_F
|
|
1265
|
+
#undef FC_RB
|
|
1266
|
+
#undef FC_CB
|
|
1559
1267
|
}
|
|
1560
1268
|
|
|
1561
|
-
|
|
1562
|
-
device const float4 * src0,
|
|
1563
|
-
device float4 * dst,
|
|
1564
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1565
|
-
dst[tpig] = step(0.0f, src0[tpig]);
|
|
1566
|
-
}
|
|
1269
|
+
typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
|
|
1567
1270
|
|
|
1568
|
-
kernel
|
|
1569
|
-
|
|
1570
|
-
device float * dst,
|
|
1571
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1572
|
-
const float x = src0[tpig];
|
|
1573
|
-
dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
|
|
1574
|
-
}
|
|
1271
|
+
template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float, float, float>;
|
|
1272
|
+
template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float4, float4, float4>;
|
|
1575
1273
|
|
|
1576
|
-
kernel void
|
|
1577
|
-
|
|
1578
|
-
device
|
|
1579
|
-
|
|
1580
|
-
|
|
1581
|
-
|
|
1582
|
-
|
|
1274
|
+
kernel void kernel_add_id(
|
|
1275
|
+
constant ggml_metal_kargs_add_id & args,
|
|
1276
|
+
device const char * src0,
|
|
1277
|
+
device const char * src1,
|
|
1278
|
+
device const char * src2,
|
|
1279
|
+
device char * dst,
|
|
1280
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1281
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
1282
|
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
1283
|
+
const int i1 = tgpig.x;
|
|
1284
|
+
const int i2 = tgpig.y;
|
|
1583
1285
|
|
|
1584
|
-
|
|
1585
|
-
device const float * src0,
|
|
1586
|
-
device float * dst,
|
|
1587
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1588
|
-
const float x = src0[tpig];
|
|
1589
|
-
dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
|
|
1590
|
-
}
|
|
1286
|
+
const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
|
|
1591
1287
|
|
|
1592
|
-
|
|
1593
|
-
|
|
1594
|
-
device float4 * dst,
|
|
1595
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1596
|
-
const float4 x = src0[tpig];
|
|
1597
|
-
dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
|
|
1598
|
-
}
|
|
1288
|
+
const size_t nb1 = args.ne0 * sizeof(float);
|
|
1289
|
+
const size_t nb2 = args.ne1 * nb1;
|
|
1599
1290
|
|
|
1600
|
-
|
|
1601
|
-
|
|
1602
|
-
|
|
1603
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1604
|
-
dst[tpig] = exp(src0[tpig]);
|
|
1605
|
-
}
|
|
1291
|
+
device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
|
|
1292
|
+
device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
|
|
1293
|
+
device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
|
|
1606
1294
|
|
|
1607
|
-
|
|
1608
|
-
|
|
1609
|
-
|
|
1610
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1611
|
-
dst[tpig] = exp(src0[tpig]);
|
|
1295
|
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
1296
|
+
dst_row[i0] = src0_row[i0] + src1_row[i0];
|
|
1297
|
+
}
|
|
1612
1298
|
}
|
|
1613
1299
|
|
|
1614
|
-
|
|
1615
|
-
|
|
1616
|
-
|
|
1617
|
-
|
|
1618
|
-
|
|
1619
|
-
|
|
1620
|
-
|
|
1300
|
+
template<typename T>
|
|
1301
|
+
kernel void kernel_repeat(
|
|
1302
|
+
constant ggml_metal_kargs_repeat & args,
|
|
1303
|
+
device const char * src0,
|
|
1304
|
+
device char * dst,
|
|
1305
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1306
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
1307
|
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
1308
|
+
const int i3 = tgpig.z;
|
|
1309
|
+
const int i2 = tgpig.y;
|
|
1310
|
+
const int i1 = tgpig.x;
|
|
1621
1311
|
|
|
1622
|
-
|
|
1623
|
-
|
|
1624
|
-
|
|
1625
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1626
|
-
device const float4 & x = src0[tpig];
|
|
1627
|
-
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
|
|
1628
|
-
}
|
|
1312
|
+
const int i03 = i3%args.ne03;
|
|
1313
|
+
const int i02 = i2%args.ne02;
|
|
1314
|
+
const int i01 = i1%args.ne01;
|
|
1629
1315
|
|
|
1630
|
-
|
|
1631
|
-
|
|
1632
|
-
device float * dst,
|
|
1633
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
1634
|
-
dst[tpig] = exp(src0[tpig]) - 1.0f;
|
|
1635
|
-
}
|
|
1316
|
+
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
|
|
1317
|
+
device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1;
|
|
1636
1318
|
|
|
1637
|
-
|
|
1638
|
-
|
|
1639
|
-
device
|
|
1640
|
-
|
|
1641
|
-
dst[tpig] = exp(src0[tpig]) - 1.0f;
|
|
1319
|
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
1320
|
+
const int i00 = i0%args.ne00;
|
|
1321
|
+
*((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
|
|
1322
|
+
}
|
|
1642
1323
|
}
|
|
1643
1324
|
|
|
1325
|
+
typedef decltype(kernel_repeat<float>) kernel_repeat_t;
|
|
1326
|
+
|
|
1327
|
+
template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
|
|
1328
|
+
template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
|
|
1329
|
+
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
|
|
1330
|
+
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
|
|
1331
|
+
|
|
1644
1332
|
kernel void kernel_reglu_f32(
|
|
1645
1333
|
constant ggml_metal_kargs_glu & args,
|
|
1646
1334
|
device const char * src0,
|
|
@@ -1824,33 +1512,35 @@ kernel void kernel_op_sum_f32(
|
|
|
1824
1512
|
}
|
|
1825
1513
|
}
|
|
1826
1514
|
|
|
1827
|
-
|
|
1828
|
-
|
|
1515
|
+
constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
|
|
1516
|
+
|
|
1517
|
+
template <typename T0, typename T>
|
|
1518
|
+
kernel void kernel_sum_rows_impl(
|
|
1829
1519
|
constant ggml_metal_kargs_sum_rows & args,
|
|
1830
|
-
device const
|
|
1831
|
-
device
|
|
1832
|
-
threadgroup
|
|
1520
|
+
device const char * src0,
|
|
1521
|
+
device char * dst,
|
|
1522
|
+
threadgroup char * shmem [[threadgroup(0)]],
|
|
1833
1523
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1834
1524
|
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
1835
1525
|
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
|
1836
1526
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
1837
1527
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
1838
|
-
|
|
1839
|
-
int64_t i2 = tgpig.y;
|
|
1840
|
-
int64_t i1 = tgpig.x;
|
|
1528
|
+
#define FC_OP FC_sum_rows_op
|
|
1841
1529
|
|
|
1842
|
-
|
|
1843
|
-
|
|
1844
|
-
|
|
1530
|
+
const int i3 = tgpig.z;
|
|
1531
|
+
const int i2 = tgpig.y;
|
|
1532
|
+
const int i1 = tgpig.x;
|
|
1533
|
+
|
|
1534
|
+
threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
|
|
1845
1535
|
|
|
1846
1536
|
if (sgitg == 0) {
|
|
1847
|
-
|
|
1537
|
+
shmem_t[tiisg] = 0.0f;
|
|
1848
1538
|
}
|
|
1849
1539
|
|
|
1850
|
-
device const
|
|
1851
|
-
device
|
|
1540
|
+
device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
|
1541
|
+
device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
|
1852
1542
|
|
|
1853
|
-
|
|
1543
|
+
T0 sumf = T0(0.0f);
|
|
1854
1544
|
|
|
1855
1545
|
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
|
1856
1546
|
sumf += src_row[i0];
|
|
@@ -1861,23 +1551,33 @@ kernel void kernel_sum_rows(
|
|
|
1861
1551
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1862
1552
|
|
|
1863
1553
|
if (tiisg == 0) {
|
|
1864
|
-
|
|
1554
|
+
shmem_t[sgitg] = sumf;
|
|
1865
1555
|
}
|
|
1866
1556
|
|
|
1867
1557
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1868
1558
|
|
|
1869
|
-
sumf =
|
|
1559
|
+
sumf = shmem_t[tiisg];
|
|
1870
1560
|
sumf = simd_sum(sumf);
|
|
1871
1561
|
|
|
1872
1562
|
if (tpitg.x == 0) {
|
|
1873
|
-
|
|
1563
|
+
if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
|
|
1564
|
+
if (is_same<float4, T0>::value) {
|
|
1565
|
+
dst_row[0] = sum(sumf) / (4*args.ne00);
|
|
1566
|
+
} else {
|
|
1567
|
+
dst_row[0] = sum(sumf) / args.ne00;
|
|
1568
|
+
}
|
|
1569
|
+
} else {
|
|
1570
|
+
dst_row[0] = sum(sumf);
|
|
1571
|
+
}
|
|
1874
1572
|
}
|
|
1573
|
+
|
|
1574
|
+
#undef FC_OP
|
|
1875
1575
|
}
|
|
1876
1576
|
|
|
1877
|
-
typedef decltype(
|
|
1577
|
+
typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;
|
|
1878
1578
|
|
|
1879
|
-
template [[host_name("
|
|
1880
|
-
template [[host_name("
|
|
1579
|
+
template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float, float>;
|
|
1580
|
+
template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;
|
|
1881
1581
|
|
|
1882
1582
|
template<typename T>
|
|
1883
1583
|
kernel void kernel_cumsum_blk(
|
|
@@ -2689,51 +2389,347 @@ kernel void kernel_rwkv_wkv7_f32(
|
|
|
2689
2389
|
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
|
|
2690
2390
|
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
|
|
2691
2391
|
|
|
2692
|
-
for (uint t = start_t; t < end_t; t += C) {
|
|
2693
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
2694
|
-
_r[tid] = r[t];
|
|
2695
|
-
_w[tid] = w[t];
|
|
2696
|
-
_k[tid] = k[t];
|
|
2697
|
-
_a[tid] = a[t];
|
|
2698
|
-
_b[tid] = b[t];
|
|
2392
|
+
for (uint t = start_t; t < end_t; t += C) {
|
|
2393
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
2394
|
+
_r[tid] = r[t];
|
|
2395
|
+
_w[tid] = w[t];
|
|
2396
|
+
_k[tid] = k[t];
|
|
2397
|
+
_a[tid] = a[t];
|
|
2398
|
+
_b[tid] = b[t];
|
|
2399
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
2400
|
+
|
|
2401
|
+
const float v_val = v[t];
|
|
2402
|
+
float y = 0.0, sa = 0.0;
|
|
2403
|
+
|
|
2404
|
+
float4 sa_vec(0.0);
|
|
2405
|
+
|
|
2406
|
+
for (uint j = 0; j < head_size; j += 4) {
|
|
2407
|
+
float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
|
|
2408
|
+
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
|
2409
|
+
sa_vec += a_vec * s_vec;
|
|
2410
|
+
}
|
|
2411
|
+
sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
|
|
2412
|
+
|
|
2413
|
+
for (uint j = 0; j < head_size; j += 4) {
|
|
2414
|
+
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
|
2415
|
+
float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
|
|
2416
|
+
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
|
2417
|
+
float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
|
|
2418
|
+
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
|
2419
|
+
|
|
2420
|
+
float4 kv = k_vec * v_val;
|
|
2421
|
+
|
|
2422
|
+
s_vec = s_vec * w_vec + kv + sa * b_vec;
|
|
2423
|
+
y += dot(s_vec, r_vec);
|
|
2424
|
+
|
|
2425
|
+
state[j] = s_vec[0];
|
|
2426
|
+
state[j+1] = s_vec[1];
|
|
2427
|
+
state[j+2] = s_vec[2];
|
|
2428
|
+
state[j+3] = s_vec[3];
|
|
2429
|
+
}
|
|
2430
|
+
|
|
2431
|
+
dst[t] = y;
|
|
2432
|
+
}
|
|
2433
|
+
|
|
2434
|
+
for (uint i = 0; i < head_size; i++) {
|
|
2435
|
+
dst[T * C + batch_id * state_size + head_id * head_size * head_size
|
|
2436
|
+
+ tid * head_size + i] = state[i];
|
|
2437
|
+
}
|
|
2438
|
+
}
|
|
2439
|
+
|
|
2440
|
+
constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
|
|
2441
|
+
constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
|
|
2442
|
+
|
|
2443
|
+
#if 1
|
|
2444
|
+
template<short NSG>
|
|
2445
|
+
kernel void kernel_gated_delta_net_impl(
|
|
2446
|
+
constant ggml_metal_kargs_gated_delta_net & args,
|
|
2447
|
+
device const char * q,
|
|
2448
|
+
device const char * k,
|
|
2449
|
+
device const char * v,
|
|
2450
|
+
device const char * g,
|
|
2451
|
+
device const char * b,
|
|
2452
|
+
device const char * s,
|
|
2453
|
+
device char * dst,
|
|
2454
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2455
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
2456
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
2457
|
+
#define S_v FC_gated_delta_net_ne20
|
|
2458
|
+
#define G FC_gated_delta_net_ne30
|
|
2459
|
+
|
|
2460
|
+
const uint tx = tpitg.x;
|
|
2461
|
+
const uint ty = tpitg.y;
|
|
2462
|
+
|
|
2463
|
+
const uint i23 = tgpig.z; // B
|
|
2464
|
+
const uint i21 = tgpig.y; // H
|
|
2465
|
+
const uint i20 = tgpig.x*NSG + ty;
|
|
2466
|
+
|
|
2467
|
+
const uint i01 = i21 % args.ne01;
|
|
2468
|
+
const uint i11 = i21 % args.ne11;
|
|
2469
|
+
|
|
2470
|
+
const float scale = 1.0f / sqrt((float)S_v);
|
|
2471
|
+
|
|
2472
|
+
// state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
|
|
2473
|
+
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
|
2474
|
+
|
|
2475
|
+
float ls[NSG];
|
|
2476
|
+
|
|
2477
|
+
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
|
2478
|
+
const short is = tx*NSG + j;
|
|
2479
|
+
ls[j] = s_ptr[is];
|
|
2480
|
+
}
|
|
2481
|
+
|
|
2482
|
+
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
|
|
2483
|
+
|
|
2484
|
+
device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
|
|
2485
|
+
device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
|
|
2486
|
+
device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
|
|
2487
|
+
|
|
2488
|
+
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
|
|
2489
|
+
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
|
|
2490
|
+
|
|
2491
|
+
for (short t = 0; t < args.ne22; t++) {
|
|
2492
|
+
float s_k = 0.0f;
|
|
2493
|
+
|
|
2494
|
+
if (G == 1) {
|
|
2495
|
+
const float g_exp = exp(g_ptr[0]);
|
|
2496
|
+
|
|
2497
|
+
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
|
2498
|
+
const short is = tx*NSG + j;
|
|
2499
|
+
ls[j] *= g_exp;
|
|
2500
|
+
|
|
2501
|
+
s_k += ls[j]*k_ptr[is];
|
|
2502
|
+
}
|
|
2503
|
+
} else {
|
|
2504
|
+
// KDA
|
|
2505
|
+
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
|
2506
|
+
const short is = tx*NSG + j;
|
|
2507
|
+
ls[j] *= exp(g_ptr[is]);
|
|
2508
|
+
|
|
2509
|
+
s_k += ls[j]*k_ptr[is];
|
|
2510
|
+
}
|
|
2511
|
+
}
|
|
2512
|
+
|
|
2513
|
+
s_k = simd_sum(s_k);
|
|
2514
|
+
|
|
2515
|
+
const float d = (v_ptr[i20] - s_k)*b_ptr[0];
|
|
2516
|
+
|
|
2517
|
+
float y = 0.0f;
|
|
2518
|
+
|
|
2519
|
+
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
|
2520
|
+
const short is = tx*NSG + j;
|
|
2521
|
+
ls[j] += k_ptr[is]*d;
|
|
2522
|
+
|
|
2523
|
+
y += ls[j]*q_ptr[is];
|
|
2524
|
+
}
|
|
2525
|
+
|
|
2526
|
+
y = simd_sum(y);
|
|
2527
|
+
|
|
2528
|
+
if (tx == 0) {
|
|
2529
|
+
dst_attn[t*args.ne21*S_v] = y*scale;
|
|
2530
|
+
}
|
|
2531
|
+
|
|
2532
|
+
q_ptr += args.ns02;
|
|
2533
|
+
k_ptr += args.ns12;
|
|
2534
|
+
v_ptr += args.ns22;
|
|
2535
|
+
|
|
2536
|
+
b_ptr += args.ne21;
|
|
2537
|
+
g_ptr += args.ne21*G;
|
|
2538
|
+
}
|
|
2539
|
+
|
|
2540
|
+
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
|
2541
|
+
|
|
2542
|
+
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
|
2543
|
+
const short is = tx*NSG + j;
|
|
2544
|
+
dst_state[is] = ls[j];
|
|
2545
|
+
}
|
|
2546
|
+
|
|
2547
|
+
#undef S_v
|
|
2548
|
+
#undef G
|
|
2549
|
+
}
|
|
2550
|
+
|
|
2551
|
+
typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;
|
|
2552
|
+
|
|
2553
|
+
template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>;
|
|
2554
|
+
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>;
|
|
2555
|
+
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>;
|
|
2556
|
+
|
|
2557
|
+
#else
|
|
2558
|
+
// a simplified version of the above
|
|
2559
|
+
// no performance improvement, so keep the above version for now
|
|
2560
|
+
|
|
2561
|
+
template<typename T, short NSG>
|
|
2562
|
+
kernel void kernel_gated_delta_net_impl(
|
|
2563
|
+
constant ggml_metal_kargs_gated_delta_net & args,
|
|
2564
|
+
device const char * q,
|
|
2565
|
+
device const char * k,
|
|
2566
|
+
device const char * v,
|
|
2567
|
+
device const char * g,
|
|
2568
|
+
device const char * b,
|
|
2569
|
+
device const char * s,
|
|
2570
|
+
device char * dst,
|
|
2571
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2572
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
2573
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
2574
|
+
#define S_v FC_gated_delta_net_ne20
|
|
2575
|
+
#define G FC_gated_delta_net_ne30
|
|
2576
|
+
|
|
2577
|
+
const uint tx = tpitg.x;
|
|
2578
|
+
const uint ty = tpitg.y;
|
|
2579
|
+
|
|
2580
|
+
const uint i23 = tgpig.z; // B
|
|
2581
|
+
const uint i21 = tgpig.y; // H
|
|
2582
|
+
const uint i20 = tgpig.x*NSG + ty;
|
|
2583
|
+
|
|
2584
|
+
const uint i01 = i21 % args.ne01;
|
|
2585
|
+
const uint i11 = i21 % args.ne11;
|
|
2586
|
+
|
|
2587
|
+
const float scale = 1.0f / sqrt((float)S_v);
|
|
2588
|
+
|
|
2589
|
+
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
|
|
2590
|
+
|
|
2591
|
+
float lsf[NSG];
|
|
2592
|
+
|
|
2593
|
+
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
|
2594
|
+
const short is = tx*NSG + j;
|
|
2595
|
+
lsf[j] = s_ptr[is*S_v];
|
|
2596
|
+
}
|
|
2597
|
+
|
|
2598
|
+
thread T * ls = (thread T *) (lsf);
|
|
2599
|
+
|
|
2600
|
+
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
|
|
2601
|
+
|
|
2602
|
+
device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
|
|
2603
|
+
device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
|
|
2604
|
+
device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
|
|
2605
|
+
|
|
2606
|
+
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
|
|
2607
|
+
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
|
|
2608
|
+
|
|
2609
|
+
for (short t = 0; t < args.ne22; t++) {
|
|
2610
|
+
device const T * qt_ptr = (device const T *) (q_ptr);
|
|
2611
|
+
device const T * kt_ptr = (device const T *) (k_ptr);
|
|
2612
|
+
device const T * gt_ptr = (device const T *) (g_ptr);
|
|
2613
|
+
|
|
2614
|
+
if (G == 1) {
|
|
2615
|
+
*ls *= exp(g_ptr[0]);
|
|
2616
|
+
} else {
|
|
2617
|
+
// KDA
|
|
2618
|
+
*ls *= exp(gt_ptr[tx]);
|
|
2619
|
+
}
|
|
2620
|
+
|
|
2621
|
+
const float s_k = simd_sum(dot(*ls, kt_ptr[tx]));
|
|
2622
|
+
|
|
2623
|
+
const float d = (v_ptr[i20] - s_k)*b_ptr[0];
|
|
2624
|
+
|
|
2625
|
+
*ls += kt_ptr[tx]*d;
|
|
2626
|
+
|
|
2627
|
+
const float y = simd_sum(dot(*ls, qt_ptr[tx]));
|
|
2628
|
+
|
|
2629
|
+
if (tx == 0) {
|
|
2630
|
+
*dst_attn = y*scale;
|
|
2631
|
+
}
|
|
2632
|
+
|
|
2633
|
+
q_ptr += args.ns02;
|
|
2634
|
+
k_ptr += args.ns12;
|
|
2635
|
+
v_ptr += args.ns22;
|
|
2636
|
+
|
|
2637
|
+
b_ptr += args.ne21;
|
|
2638
|
+
g_ptr += args.ne21*G;
|
|
2639
|
+
|
|
2640
|
+
dst_attn += args.ne21*S_v;
|
|
2641
|
+
}
|
|
2642
|
+
|
|
2643
|
+
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
|
|
2644
|
+
device T * dstt_state = (device T *) (dst_state);
|
|
2645
|
+
|
|
2646
|
+
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
|
2647
|
+
const short is = tx*NSG + j;
|
|
2648
|
+
dst_state[is*S_v] = lsf[j];
|
|
2649
|
+
}
|
|
2650
|
+
|
|
2651
|
+
#undef S_v
|
|
2652
|
+
#undef G
|
|
2653
|
+
}
|
|
2654
|
+
|
|
2655
|
+
typedef decltype(kernel_gated_delta_net_impl<float4, 4>) kernel_gated_delta_net_t;
|
|
2656
|
+
|
|
2657
|
+
template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float, 1>;
|
|
2658
|
+
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float2, 2>;
|
|
2659
|
+
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float4, 4>;
|
|
2660
|
+
#endif
|
|
2661
|
+
|
|
2662
|
+
constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
|
|
2663
|
+
constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]];
|
|
2664
|
+
constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]];
|
|
2665
|
+
|
|
2666
|
+
kernel void kernel_solve_tri_f32(
|
|
2667
|
+
constant ggml_metal_kargs_solve_tri & args,
|
|
2668
|
+
device const char * src0,
|
|
2669
|
+
device const char * src1,
|
|
2670
|
+
device char * dst,
|
|
2671
|
+
threadgroup char * shmem [[threadgroup(0)]],
|
|
2672
|
+
ushort3 tgpig[[threadgroup_position_in_grid]],
|
|
2673
|
+
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
|
2674
|
+
ushort tiisg[[thread_index_in_simdgroup]],
|
|
2675
|
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
2676
|
+
constexpr short NW = N_SIMDWIDTH;
|
|
2677
|
+
|
|
2678
|
+
const short NSG = FC_solve_tri_nsg;
|
|
2679
|
+
const short N = FC_solve_tri_n;
|
|
2680
|
+
const short K = FC_solve_tri_k;
|
|
2681
|
+
const short NP = PAD2(N, NW);
|
|
2682
|
+
|
|
2683
|
+
const int32_t i03 = tgpig.z;
|
|
2684
|
+
const int32_t i02 = tgpig.y;
|
|
2685
|
+
const int32_t i01 = tgpig.x*NSG + sgitg;
|
|
2686
|
+
|
|
2687
|
+
threadgroup float * sh0 = (threadgroup float *) shmem;
|
|
2688
|
+
|
|
2689
|
+
device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N;
|
|
2690
|
+
device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01;
|
|
2691
|
+
device float * dst_ptr = (device float *)(dst + i02 * args.nb2 + i03 * args.nb3) + i01;
|
|
2692
|
+
|
|
2693
|
+
for (short rr = 0; rr < N; rr += NSG) {
|
|
2699
2694
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
2700
2695
|
|
|
2701
|
-
|
|
2702
|
-
|
|
2696
|
+
{
|
|
2697
|
+
threadgroup float * sh0_cur = sh0 + sgitg*NP;
|
|
2703
2698
|
|
|
2704
|
-
|
|
2699
|
+
for (short t = 0; t*NW < N; ++t) {
|
|
2700
|
+
const short idx = t*NW + tiisg;
|
|
2701
|
+
sh0_cur[idx] = src0_ptr[idx];
|
|
2702
|
+
}
|
|
2705
2703
|
|
|
2706
|
-
|
|
2707
|
-
float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
|
|
2708
|
-
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
|
2709
|
-
sa_vec += a_vec * s_vec;
|
|
2704
|
+
src0_ptr += NSG*N;
|
|
2710
2705
|
}
|
|
2711
|
-
sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
|
|
2712
2706
|
|
|
2713
|
-
|
|
2714
|
-
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
|
2715
|
-
float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
|
|
2716
|
-
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
|
2717
|
-
float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
|
|
2718
|
-
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
|
2707
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
2719
2708
|
|
|
2720
|
-
|
|
2709
|
+
if (i01 >= args.ne10) {
|
|
2710
|
+
continue;
|
|
2711
|
+
}
|
|
2721
2712
|
|
|
2722
|
-
|
|
2723
|
-
|
|
2713
|
+
for (short ir = 0; ir < NSG && rr + ir < N; ++ir) {
|
|
2714
|
+
const short r = rr + ir;
|
|
2724
2715
|
|
|
2725
|
-
|
|
2726
|
-
state[j+1] = s_vec[1];
|
|
2727
|
-
state[j+2] = s_vec[2];
|
|
2728
|
-
state[j+3] = s_vec[3];
|
|
2729
|
-
}
|
|
2716
|
+
threadgroup float * sh0_cur = sh0 + ir*NP;
|
|
2730
2717
|
|
|
2731
|
-
|
|
2732
|
-
}
|
|
2718
|
+
float sum = 0.0f;
|
|
2733
2719
|
|
|
2734
|
-
|
|
2735
|
-
|
|
2736
|
-
|
|
2720
|
+
for (short t = 0; t*NW < r; ++t) {
|
|
2721
|
+
const short idx = t*NW + tiisg;
|
|
2722
|
+
sum += sh0_cur[idx] * dst_ptr[idx*K] * (idx < r);
|
|
2723
|
+
}
|
|
2724
|
+
|
|
2725
|
+
sum = simd_sum(sum);
|
|
2726
|
+
|
|
2727
|
+
if (tiisg == 0) {
|
|
2728
|
+
const float diag = sh0_cur[r];
|
|
2729
|
+
|
|
2730
|
+
dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag;
|
|
2731
|
+
}
|
|
2732
|
+
}
|
|
2737
2733
|
}
|
|
2738
2734
|
}
|
|
2739
2735
|
|
|
@@ -2970,26 +2966,32 @@ template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_f
|
|
|
2970
2966
|
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
|
|
2971
2967
|
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
|
|
2972
2968
|
|
|
2973
|
-
|
|
2969
|
+
template <typename T0, typename T>
|
|
2970
|
+
kernel void kernel_l2_norm_impl(
|
|
2974
2971
|
constant ggml_metal_kargs_l2_norm & args,
|
|
2975
2972
|
device const char * src0,
|
|
2976
2973
|
device char * dst,
|
|
2977
2974
|
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
|
2978
|
-
|
|
2979
|
-
|
|
2980
|
-
ushort
|
|
2981
|
-
ushort
|
|
2982
|
-
|
|
2975
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2976
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
2977
|
+
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
|
2978
|
+
ushort tiisg[[thread_index_in_simdgroup]],
|
|
2979
|
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
2980
|
+
const int i03 = tgpig.z;
|
|
2981
|
+
const int i02 = tgpig.y;
|
|
2982
|
+
const int i01 = tgpig.x;
|
|
2983
|
+
|
|
2983
2984
|
if (sgitg == 0) {
|
|
2984
2985
|
shmem_f32[tiisg] = 0.0f;
|
|
2985
2986
|
}
|
|
2986
2987
|
|
|
2987
|
-
device const
|
|
2988
|
+
device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
|
2989
|
+
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
|
|
2988
2990
|
|
|
2989
2991
|
float sumf = 0.0f;
|
|
2990
2992
|
|
|
2991
2993
|
// parallel sum
|
|
2992
|
-
for (int i00 = tpitg; i00 < args.
|
|
2994
|
+
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
|
|
2993
2995
|
sumf += dot(x[i00], x[i00]);
|
|
2994
2996
|
}
|
|
2995
2997
|
sumf = simd_sum(sumf);
|
|
@@ -3005,14 +3007,18 @@ kernel void kernel_l2_norm_f32(
|
|
|
3005
3007
|
sumf = shmem_f32[tiisg];
|
|
3006
3008
|
sumf = simd_sum(sumf);
|
|
3007
3009
|
|
|
3008
|
-
const float scale = 1.0f/sqrt(
|
|
3010
|
+
const float scale = 1.0f/max(sqrt(sumf), args.eps);
|
|
3009
3011
|
|
|
3010
|
-
|
|
3011
|
-
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
|
3012
|
+
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
|
|
3012
3013
|
y[i00] = x[i00] * scale;
|
|
3013
3014
|
}
|
|
3014
3015
|
}
|
|
3015
3016
|
|
|
3017
|
+
typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
|
|
3018
|
+
|
|
3019
|
+
template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
|
|
3020
|
+
template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
|
|
3021
|
+
|
|
3016
3022
|
kernel void kernel_group_norm_f32(
|
|
3017
3023
|
constant ggml_metal_kargs_group_norm & args,
|
|
3018
3024
|
device const float * src0,
|
|
@@ -3700,6 +3706,13 @@ template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4
|
|
|
3700
3706
|
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
|
|
3701
3707
|
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>;
|
|
3702
3708
|
|
|
3709
|
+
#if defined(GGML_METAL_HAS_BF16)
|
|
3710
|
+
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, bfloat4, 4, dequantize_bf16_t4>;
|
|
3711
|
+
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, bfloat4, 4, dequantize_bf16_t4>;
|
|
3712
|
+
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, bfloat4, 4, dequantize_bf16_t4>;
|
|
3713
|
+
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>;
|
|
3714
|
+
#endif
|
|
3715
|
+
|
|
3703
3716
|
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
|
|
3704
3717
|
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
|
|
3705
3718
|
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
|
|
@@ -3750,6 +3763,16 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4
|
|
|
3750
3763
|
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
|
|
3751
3764
|
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
|
|
3752
3765
|
|
|
3766
|
+
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q2_K, 256, dequantize_q2_K>;
|
|
3767
|
+
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q2_K, 256, dequantize_q2_K>;
|
|
3768
|
+
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q2_K, 256, dequantize_q2_K>;
|
|
3769
|
+
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q2_K, 256, dequantize_q2_K>;
|
|
3770
|
+
|
|
3771
|
+
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q3_K, 256, dequantize_q3_K>;
|
|
3772
|
+
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q3_K, 256, dequantize_q3_K>;
|
|
3773
|
+
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q3_K, 256, dequantize_q3_K>;
|
|
3774
|
+
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q3_K, 256, dequantize_q3_K>;
|
|
3775
|
+
|
|
3753
3776
|
template<typename T0, typename T1, short NR0, typename args_t>
|
|
3754
3777
|
void kernel_mul_mv_t_t_impl(
|
|
3755
3778
|
args_t args,
|
|
@@ -4437,7 +4460,7 @@ kernel void kernel_im2col(
|
|
|
4437
4460
|
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
|
|
4438
4461
|
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
|
|
4439
4462
|
|
|
4440
|
-
// TODO:
|
|
4463
|
+
// TODO: obsolete -- remove
|
|
4441
4464
|
//typedef void (im2col_ext_t)(
|
|
4442
4465
|
// constant ggml_metal_kargs_im2col & args,
|
|
4443
4466
|
// device const float * x,
|
|
@@ -4749,7 +4772,9 @@ kernel void kernel_conv_transpose_2d<half>(
|
|
|
4749
4772
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
4750
4773
|
uint3 ntg[[threads_per_threadgroup]]);
|
|
4751
4774
|
|
|
4752
|
-
|
|
4775
|
+
constant bool FC_upscale_aa [[function_constant(FC_UPSCALE + 0)]];
|
|
4776
|
+
|
|
4777
|
+
kernel void kernel_upscale_nearest_f32(
|
|
4753
4778
|
constant ggml_metal_kargs_upscale & args,
|
|
4754
4779
|
device const char * src0,
|
|
4755
4780
|
device char * dst,
|
|
@@ -4775,6 +4800,156 @@ kernel void kernel_upscale_f32(
|
|
|
4775
4800
|
}
|
|
4776
4801
|
}
|
|
4777
4802
|
|
|
4803
|
+
static inline float bilinear_tri(float x) {
|
|
4804
|
+
return MAX(0.0f, 1.0f - fabs(x));
|
|
4805
|
+
}
|
|
4806
|
+
|
|
4807
|
+
kernel void kernel_upscale_bilinear_f32(
|
|
4808
|
+
constant ggml_metal_kargs_upscale & args,
|
|
4809
|
+
device const char * src0,
|
|
4810
|
+
device char * dst,
|
|
4811
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4812
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
4813
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
4814
|
+
|
|
4815
|
+
const int64_t i3 = tgpig.z;
|
|
4816
|
+
const int64_t i2 = tgpig.y;
|
|
4817
|
+
const int64_t i1 = tgpig.x;
|
|
4818
|
+
|
|
4819
|
+
const int64_t i03 = i3 / args.sf3;
|
|
4820
|
+
const int64_t i02 = i2 / args.sf2;
|
|
4821
|
+
|
|
4822
|
+
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
|
|
4823
|
+
const int64_t i01 = MAX(0, MIN(args.ne01 - 1, (int64_t)floor(f01)));
|
|
4824
|
+
const int64_t i01p = MAX(0, MIN(args.ne01 - 1, i01 + 1));
|
|
4825
|
+
const float fd1 = MAX(0.0f, MIN(1.0f, f01 - (float)i01));
|
|
4826
|
+
|
|
4827
|
+
src0 += i03*args.nb03 + i02*args.nb02;
|
|
4828
|
+
|
|
4829
|
+
device float * dst_ptr = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
|
|
4830
|
+
|
|
4831
|
+
if (FC_upscale_aa) {
|
|
4832
|
+
const float support0 = MAX(1.0f, 1.0f / args.sf0);
|
|
4833
|
+
const float invscale0 = 1.0f / support0;
|
|
4834
|
+
const float support1 = MAX(1.0f, 1.0f / args.sf1);
|
|
4835
|
+
const float invscale1 = 1.0f / support1;
|
|
4836
|
+
|
|
4837
|
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
4838
|
+
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
|
|
4839
|
+
|
|
4840
|
+
int64_t x_min = MAX((int64_t)0, (int64_t)floor(f00 - support0 + args.poffs));
|
|
4841
|
+
int64_t x_max = MIN(args.ne00, (int64_t)ceil (f00 + support0 + args.poffs));
|
|
4842
|
+
|
|
4843
|
+
int64_t y_min = MAX((int64_t)0, (int64_t)floor(f01 - support1 + args.poffs));
|
|
4844
|
+
int64_t y_max = MIN(args.ne01, (int64_t)ceil (f01 + support1 + args.poffs));
|
|
4845
|
+
|
|
4846
|
+
float sum = 0.0f;
|
|
4847
|
+
float wsum = 0.0f;
|
|
4848
|
+
|
|
4849
|
+
for (int64_t sy = y_min; sy < y_max; ++sy) {
|
|
4850
|
+
const float wy = MAX(0.0f, 1.0f - fabs((float)sy - f01) * invscale1);
|
|
4851
|
+
for (int64_t sx = x_min; sx < x_max; ++sx) {
|
|
4852
|
+
const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
|
|
4853
|
+
const float w = wx * wy;
|
|
4854
|
+
const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
|
|
4855
|
+
sum += (*src_ptr) * w;
|
|
4856
|
+
wsum += w;
|
|
4857
|
+
}
|
|
4858
|
+
}
|
|
4859
|
+
|
|
4860
|
+
const float v = (wsum > 0.0f) ? (sum / wsum) : 0.0f;
|
|
4861
|
+
dst_ptr[i0] = v;
|
|
4862
|
+
}
|
|
4863
|
+
} else {
|
|
4864
|
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
4865
|
+
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
|
|
4866
|
+
const int64_t i00 = MAX(0, MIN(args.ne00 - 1, (int64_t)floor(f00)));
|
|
4867
|
+
const int64_t i00p = MAX(0, MIN(args.ne00 - 1, i00 + 1));
|
|
4868
|
+
const float fd0 = MAX(0.0f, MIN(1.0f, f00 - (float)i00));
|
|
4869
|
+
|
|
4870
|
+
device const float * src00 = (device const float *)(src0 + i01*args.nb01 + i00*args.nb00);
|
|
4871
|
+
device const float * src10 = (device const float *)(src0 + i01*args.nb01 + i00p*args.nb00);
|
|
4872
|
+
device const float * src01 = (device const float *)(src0 + i01p*args.nb01 + i00*args.nb00);
|
|
4873
|
+
device const float * src11 = (device const float *)(src0 + i01p*args.nb01 + i00p*args.nb00);
|
|
4874
|
+
|
|
4875
|
+
const float v =
|
|
4876
|
+
(*src00) * (1.0f - fd0) * (1.0f - fd1) +
|
|
4877
|
+
(*src10) * fd0 * (1.0f - fd1) +
|
|
4878
|
+
(*src01) * (1.0f - fd0) * fd1 +
|
|
4879
|
+
(*src11) * fd0 * fd1;
|
|
4880
|
+
|
|
4881
|
+
dst_ptr[i0] = v;
|
|
4882
|
+
}
|
|
4883
|
+
}
|
|
4884
|
+
}
|
|
4885
|
+
|
|
4886
|
+
static inline float bicubic_weight1(float x) {
|
|
4887
|
+
const float a = -0.75f;
|
|
4888
|
+
return ((a + 2) * x - (a + 3)) * x * x + 1;
|
|
4889
|
+
}
|
|
4890
|
+
|
|
4891
|
+
static inline float bicubic_weight2(float x) {
|
|
4892
|
+
const float a = -0.75f;
|
|
4893
|
+
return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
|
|
4894
|
+
}
|
|
4895
|
+
|
|
4896
|
+
kernel void kernel_upscale_bicubic_f32(
|
|
4897
|
+
constant ggml_metal_kargs_upscale & args,
|
|
4898
|
+
device const char * src0,
|
|
4899
|
+
device char * dst,
|
|
4900
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4901
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
4902
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
4903
|
+
|
|
4904
|
+
const int64_t i3 = tgpig.z;
|
|
4905
|
+
const int64_t i2 = tgpig.y;
|
|
4906
|
+
const int64_t i1 = tgpig.x;
|
|
4907
|
+
|
|
4908
|
+
const int64_t i03 = i3 / args.sf3;
|
|
4909
|
+
const int64_t i02 = i2 / args.sf2;
|
|
4910
|
+
|
|
4911
|
+
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
|
|
4912
|
+
const int64_t i01 = (int64_t)floor(f01);
|
|
4913
|
+
const float fd1 = f01 - (float)i01;
|
|
4914
|
+
|
|
4915
|
+
const float w_y0 = bicubic_weight2(fd1 + 1.0f);
|
|
4916
|
+
const float w_y1 = bicubic_weight1(fd1);
|
|
4917
|
+
const float w_y2 = bicubic_weight1(1.0f - fd1);
|
|
4918
|
+
const float w_y3 = bicubic_weight2(2.0f - fd1);
|
|
4919
|
+
|
|
4920
|
+
const device const char * src_slice = src0 + i03 * args.nb03 + i02 * args.nb02;
|
|
4921
|
+
|
|
4922
|
+
device float * dst_ptr = (device float *)(dst + i3 * args.nb3 + i2 * args.nb2 + i1 * args.nb1);
|
|
4923
|
+
|
|
4924
|
+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
4925
|
+
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
|
|
4926
|
+
const int64_t i00 = (int64_t)floor(f00);
|
|
4927
|
+
const float fd0 = f00 - (float)i00;
|
|
4928
|
+
|
|
4929
|
+
const float w_x0 = bicubic_weight2(fd0 + 1.0f);
|
|
4930
|
+
const float w_x1 = bicubic_weight1(fd0);
|
|
4931
|
+
const float w_x2 = bicubic_weight1(1.0f - fd0);
|
|
4932
|
+
const float w_x3 = bicubic_weight2(2.0f - fd0);
|
|
4933
|
+
|
|
4934
|
+
float sum = 0.0f;
|
|
4935
|
+
|
|
4936
|
+
for (int dy = -1; dy <= 2; ++dy) {
|
|
4937
|
+
const int64_t iy = MAX(0, MIN(args.ne01 - 1, i01 + dy));
|
|
4938
|
+
const float wy = (dy == -1) ? w_y0 : (dy == 0) ? w_y1 : (dy == 1) ? w_y2 : w_y3;
|
|
4939
|
+
|
|
4940
|
+
for (int dx = -1; dx <= 2; ++dx) {
|
|
4941
|
+
const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
|
|
4942
|
+
const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
|
|
4943
|
+
|
|
4944
|
+
const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
|
|
4945
|
+
sum += (*src_ptr) * wx * wy;
|
|
4946
|
+
}
|
|
4947
|
+
}
|
|
4948
|
+
|
|
4949
|
+
dst_ptr[i0] = sum;
|
|
4950
|
+
}
|
|
4951
|
+
}
|
|
4952
|
+
|
|
4778
4953
|
kernel void kernel_pad_f32(
|
|
4779
4954
|
constant ggml_metal_kargs_pad & args,
|
|
4780
4955
|
device const char * src0,
|
|
@@ -5114,24 +5289,6 @@ kernel void kernel_argsort_merge_f32_i32(
|
|
|
5114
5289
|
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
|
|
5115
5290
|
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
|
|
5116
5291
|
|
|
5117
|
-
kernel void kernel_leaky_relu_f32(
|
|
5118
|
-
constant ggml_metal_kargs_leaky_relu & args,
|
|
5119
|
-
device const float * src0,
|
|
5120
|
-
device float * dst,
|
|
5121
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
5122
|
-
const float x = src0[tpig];
|
|
5123
|
-
dst[tpig] = x > 0.0f ? x : x * args.slope;
|
|
5124
|
-
}
|
|
5125
|
-
|
|
5126
|
-
kernel void kernel_leaky_relu_f32_4(
|
|
5127
|
-
constant ggml_metal_kargs_leaky_relu & args,
|
|
5128
|
-
device const float4 * src0,
|
|
5129
|
-
device float4 * dst,
|
|
5130
|
-
uint tpig[[thread_position_in_grid]]) {
|
|
5131
|
-
const float4 x = src0[tpig];
|
|
5132
|
-
dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
|
|
5133
|
-
}
|
|
5134
|
-
|
|
5135
5292
|
constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
|
|
5136
5293
|
|
|
5137
5294
|
constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
|
|
@@ -5208,6 +5365,7 @@ constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_E
|
|
|
5208
5365
|
// scan the blocks of the mask that are not masked
|
|
5209
5366
|
// 0 - masked (i.e. full of -INF, skip)
|
|
5210
5367
|
// 1 - not masked (i.e. at least one element of the mask is not -INF)
|
|
5368
|
+
// 2 - all zero
|
|
5211
5369
|
kernel void kernel_flash_attn_ext_blk(
|
|
5212
5370
|
constant ggml_metal_kargs_flash_attn_ext_blk & args,
|
|
5213
5371
|
device const char * mask,
|
|
@@ -5229,27 +5387,29 @@ kernel void kernel_flash_attn_ext_blk(
|
|
|
5229
5387
|
|
|
5230
5388
|
device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;
|
|
5231
5389
|
|
|
5232
|
-
// fast route
|
|
5233
|
-
if (res == 0) {
|
|
5234
|
-
if (simd_max(*mask_src) > -MAXHALF/2) {
|
|
5235
|
-
res = 1;
|
|
5236
|
-
}
|
|
5237
|
-
}
|
|
5238
|
-
|
|
5239
5390
|
// detailed check of the elements of the block
|
|
5240
5391
|
if ((C > NW || Q > 1) && res == 0) {
|
|
5241
|
-
half
|
|
5392
|
+
half mmin = MAXHALF;
|
|
5393
|
+
half mmax = -MAXHALF;
|
|
5242
5394
|
|
|
5243
5395
|
FOR_UNROLL (short j = 0; j < Q; ++j) {
|
|
5244
5396
|
FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
|
|
5245
|
-
|
|
5397
|
+
mmin = min(mmin, mask_src[ii*NW]);
|
|
5398
|
+
mmax = max(mmax, mask_src[ii*NW]);
|
|
5246
5399
|
}
|
|
5247
5400
|
|
|
5248
5401
|
mask_src += args.nb31/2;
|
|
5249
5402
|
}
|
|
5250
5403
|
|
|
5251
|
-
|
|
5252
|
-
|
|
5404
|
+
mmin = simd_min(mmin);
|
|
5405
|
+
mmax = simd_max(mmax);
|
|
5406
|
+
|
|
5407
|
+
if (mmax > -MAXHALF) {
|
|
5408
|
+
if (mmin == 0.0 && mmax == 0.0) {
|
|
5409
|
+
res = 2;
|
|
5410
|
+
} else {
|
|
5411
|
+
res = 1;
|
|
5412
|
+
}
|
|
5253
5413
|
}
|
|
5254
5414
|
}
|
|
5255
5415
|
|
|
@@ -5491,9 +5651,13 @@ void kernel_flash_attn_ext_impl(
|
|
|
5491
5651
|
ic = 0;
|
|
5492
5652
|
}
|
|
5493
5653
|
|
|
5654
|
+
char blk_cur = 1;
|
|
5655
|
+
|
|
5494
5656
|
// read the mask into shared mem
|
|
5495
5657
|
if (FC_flash_attn_ext_has_mask) {
|
|
5496
|
-
|
|
5658
|
+
blk_cur = blk[ic0];
|
|
5659
|
+
|
|
5660
|
+
if (blk_cur == 0) {
|
|
5497
5661
|
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
|
|
5498
5662
|
pm2[jj] += NW;
|
|
5499
5663
|
}
|
|
@@ -5501,16 +5665,22 @@ void kernel_flash_attn_ext_impl(
|
|
|
5501
5665
|
continue;
|
|
5502
5666
|
}
|
|
5503
5667
|
|
|
5504
|
-
|
|
5505
|
-
|
|
5668
|
+
if (blk_cur == 1) {
|
|
5669
|
+
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
|
|
5670
|
+
const short j = jj*NSG + sgitg;
|
|
5506
5671
|
|
|
5507
|
-
|
|
5508
|
-
|
|
5509
|
-
|
|
5510
|
-
|
|
5511
|
-
|
|
5672
|
+
if (FC_flash_attn_ext_bc_mask) {
|
|
5673
|
+
sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
|
|
5674
|
+
} else {
|
|
5675
|
+
sm2[j*SH + tiisg] = pm2[jj][tiisg];
|
|
5676
|
+
}
|
|
5512
5677
|
|
|
5513
|
-
|
|
5678
|
+
pm2[jj] += NW;
|
|
5679
|
+
}
|
|
5680
|
+
} else if (blk_cur == 2) {
|
|
5681
|
+
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
|
|
5682
|
+
pm2[jj] += NW;
|
|
5683
|
+
}
|
|
5514
5684
|
}
|
|
5515
5685
|
|
|
5516
5686
|
#if 0
|
|
@@ -5552,9 +5722,7 @@ void kernel_flash_attn_ext_impl(
|
|
|
5552
5722
|
|
|
5553
5723
|
constexpr short NC = (C/8)/NSG;
|
|
5554
5724
|
|
|
5555
|
-
|
|
5556
|
-
#pragma unroll (DK <= 64 ? NC : 1)
|
|
5557
|
-
for (short cc = 0; cc < NC; ++cc) {
|
|
5725
|
+
FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
|
|
5558
5726
|
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
|
|
5559
5727
|
|
|
5560
5728
|
if (DK % 16 != 0) {
|
|
@@ -5575,7 +5743,9 @@ void kernel_flash_attn_ext_impl(
|
|
|
5575
5743
|
k8x8_t mk[2];
|
|
5576
5744
|
q8x8_t mq[2];
|
|
5577
5745
|
|
|
5578
|
-
|
|
5746
|
+
// note: too much unroll can tank the performance for large heads
|
|
5747
|
+
#pragma unroll (MIN(DK8/2, 4*NSG))
|
|
5748
|
+
for (short i = 0; i < DK8/2; ++i) {
|
|
5579
5749
|
simdgroup_barrier(mem_flags::mem_none);
|
|
5580
5750
|
|
|
5581
5751
|
simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
|
|
@@ -5675,10 +5845,12 @@ void kernel_flash_attn_ext_impl(
|
|
|
5675
5845
|
}
|
|
5676
5846
|
|
|
5677
5847
|
// mqk = mqk + slope*mask
|
|
5678
|
-
if (
|
|
5679
|
-
|
|
5680
|
-
|
|
5681
|
-
|
|
5848
|
+
if (blk_cur != 2) {
|
|
5849
|
+
if (FC_flash_attn_ext_has_bias) {
|
|
5850
|
+
s2 += s2_t(sm2[j*SH + tiisg])*slope;
|
|
5851
|
+
} else {
|
|
5852
|
+
s2 += s2_t(sm2[j*SH + tiisg]);
|
|
5853
|
+
}
|
|
5682
5854
|
}
|
|
5683
5855
|
|
|
5684
5856
|
M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));
|
|
@@ -5749,7 +5921,9 @@ void kernel_flash_attn_ext_impl(
|
|
|
5749
5921
|
pv += 8*NS20;
|
|
5750
5922
|
}
|
|
5751
5923
|
} else {
|
|
5752
|
-
|
|
5924
|
+
constexpr short NC = (C/8)/2;
|
|
5925
|
+
|
|
5926
|
+
FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
|
|
5753
5927
|
s8x8_t vs[2];
|
|
5754
5928
|
|
|
5755
5929
|
simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
|
|
@@ -5929,7 +6103,7 @@ template<
|
|
|
5929
6103
|
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
|
|
5930
6104
|
short DK, // K head size
|
|
5931
6105
|
short DV, // V head size
|
|
5932
|
-
short Q =
|
|
6106
|
+
short Q = OP_FLASH_ATTN_EXT_NQPSG, // queries per threadgroup
|
|
5933
6107
|
short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup
|
|
5934
6108
|
kernel void kernel_flash_attn_ext(
|
|
5935
6109
|
constant ggml_metal_kargs_flash_attn_ext & args,
|
|
@@ -5952,6 +6126,7 @@ kernel void kernel_flash_attn_ext(
|
|
|
5952
6126
|
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
|
5953
6127
|
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
|
|
5954
6128
|
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
|
|
6129
|
+
case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
|
|
5955
6130
|
}
|
|
5956
6131
|
#undef FWD_TMPL
|
|
5957
6132
|
#undef FWD_ARGS
|
|
@@ -6001,6 +6176,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_at
|
|
|
6001
6176
|
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>;
|
|
6002
6177
|
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
|
|
6003
6178
|
template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
|
|
6179
|
+
template [[host_name("kernel_flash_attn_ext_f32_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 320, 256>;
|
|
6004
6180
|
template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;
|
|
6005
6181
|
|
|
6006
6182
|
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
|
|
@@ -6015,6 +6191,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_at
|
|
|
6015
6191
|
template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
|
|
6016
6192
|
template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
|
|
6017
6193
|
template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
|
|
6194
|
+
template [[host_name("kernel_flash_attn_ext_f16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 320, 256>;
|
|
6018
6195
|
template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
|
|
6019
6196
|
|
|
6020
6197
|
#if defined(GGML_METAL_HAS_BF16)
|
|
@@ -6030,6 +6207,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_at
|
|
|
6030
6207
|
template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
|
|
6031
6208
|
template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
|
|
6032
6209
|
template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
|
|
6210
|
+
template [[host_name("kernel_flash_attn_ext_bf16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 320, 256>;
|
|
6033
6211
|
template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
|
|
6034
6212
|
#endif
|
|
6035
6213
|
|
|
@@ -6045,6 +6223,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_at
|
|
|
6045
6223
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
|
|
6046
6224
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
|
|
6047
6225
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
|
|
6226
|
+
template [[host_name("kernel_flash_attn_ext_q4_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 320, 256>;
|
|
6048
6227
|
template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
|
|
6049
6228
|
|
|
6050
6229
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
|
|
@@ -6059,6 +6238,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_at
|
|
|
6059
6238
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
|
|
6060
6239
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
|
|
6061
6240
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
|
|
6241
|
+
template [[host_name("kernel_flash_attn_ext_q4_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 320, 256>;
|
|
6062
6242
|
template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
|
|
6063
6243
|
|
|
6064
6244
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
|
|
@@ -6073,6 +6253,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_at
|
|
|
6073
6253
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
|
|
6074
6254
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
|
|
6075
6255
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
|
|
6256
|
+
template [[host_name("kernel_flash_attn_ext_q5_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 320, 256>;
|
|
6076
6257
|
template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
|
|
6077
6258
|
|
|
6078
6259
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
|
|
@@ -6087,6 +6268,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_at
|
|
|
6087
6268
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
|
|
6088
6269
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
|
|
6089
6270
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
|
|
6271
|
+
template [[host_name("kernel_flash_attn_ext_q5_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 320, 256>;
|
|
6090
6272
|
template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
|
|
6091
6273
|
|
|
6092
6274
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
|
|
@@ -6101,6 +6283,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_at
|
|
|
6101
6283
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
|
|
6102
6284
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
|
|
6103
6285
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
|
|
6286
|
+
template [[host_name("kernel_flash_attn_ext_q8_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 320, 256>;
|
|
6104
6287
|
template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
|
|
6105
6288
|
|
|
6106
6289
|
#undef FA_TYPES
|
|
@@ -6138,11 +6321,10 @@ template<
|
|
|
6138
6321
|
void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
|
|
6139
6322
|
short DK, // K head size
|
|
6140
6323
|
short DV, // V head size
|
|
6141
|
-
short NE,
|
|
6142
|
-
short Q,
|
|
6143
|
-
short C
|
|
6144
|
-
|
|
6145
|
-
void kernel_flash_attn_ext_vec_impl(
|
|
6324
|
+
short NE = 4, // head elements per thread
|
|
6325
|
+
short Q = OP_FLASH_ATTN_EXT_VEC_NQPSG, // queries per threadgroup
|
|
6326
|
+
short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
|
|
6327
|
+
kernel void kernel_flash_attn_ext_vec(
|
|
6146
6328
|
constant ggml_metal_kargs_flash_attn_ext_vec & args,
|
|
6147
6329
|
device const char * q,
|
|
6148
6330
|
device const char * k,
|
|
@@ -6159,6 +6341,7 @@ void kernel_flash_attn_ext_vec_impl(
|
|
|
6159
6341
|
static_assert(DV % 32 == 0, "DV must be divisible by 32");
|
|
6160
6342
|
|
|
6161
6343
|
#define NWG (FC_flash_attn_ext_vec_nwg)
|
|
6344
|
+
#define NSG (FC_flash_attn_ext_vec_nsg)
|
|
6162
6345
|
|
|
6163
6346
|
#define NS10 (FC_flash_attn_ext_vec_ns10)
|
|
6164
6347
|
#define NS20 (FC_flash_attn_ext_vec_ns20)
|
|
@@ -6185,14 +6368,14 @@ void kernel_flash_attn_ext_vec_impl(
|
|
|
6185
6368
|
static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
|
|
6186
6369
|
static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
|
|
6187
6370
|
|
|
6188
|
-
|
|
6371
|
+
//const short T = PK + NSG*SH; // shared memory size per query in (half)
|
|
6189
6372
|
|
|
6190
|
-
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 +
|
|
6191
|
-
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +
|
|
6192
|
-
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH +
|
|
6193
|
-
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH +
|
|
6194
|
-
threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C +
|
|
6195
|
-
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV +
|
|
6373
|
+
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
|
|
6374
|
+
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
|
|
6375
|
+
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + NSG*PK); // scratch buffer for attention
|
|
6376
|
+
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + NSG*PK); // same as above but in s4_t
|
|
6377
|
+
threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + NSG*PK); // scratch buffer for mask
|
|
6378
|
+
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + NSG*PK + NSG*SH); // scratch buffer for the results
|
|
6196
6379
|
|
|
6197
6380
|
// store the result for all queries in shared memory (the O matrix from the paper)
|
|
6198
6381
|
so4 += tiisg;
|
|
@@ -6210,11 +6393,13 @@ void kernel_flash_attn_ext_vec_impl(
|
|
|
6210
6393
|
// load heads from Q to shared memory
|
|
6211
6394
|
device const float4 * q4 = (device const float4 *) ((device const char *) q);
|
|
6212
6395
|
|
|
6213
|
-
|
|
6214
|
-
|
|
6215
|
-
|
|
6216
|
-
|
|
6217
|
-
|
|
6396
|
+
if (iq1 < args.ne01) {
|
|
6397
|
+
for (short i = tiisg; i < PK4; i += NW) {
|
|
6398
|
+
if (i < DK4) {
|
|
6399
|
+
sq4[i] = (q4_t) q4[i];
|
|
6400
|
+
} else {
|
|
6401
|
+
sq4[i] = (q4_t) 0.0f;
|
|
6402
|
+
}
|
|
6218
6403
|
}
|
|
6219
6404
|
}
|
|
6220
6405
|
|
|
@@ -6292,7 +6477,7 @@ void kernel_flash_attn_ext_vec_impl(
|
|
|
6292
6477
|
}
|
|
6293
6478
|
|
|
6294
6479
|
// skip -INF blocks
|
|
6295
|
-
if (simd_max(sm[tiisg])
|
|
6480
|
+
if (simd_max(sm[tiisg]) <= -MAXHALF) {
|
|
6296
6481
|
continue;
|
|
6297
6482
|
}
|
|
6298
6483
|
|
|
@@ -6566,57 +6751,11 @@ void kernel_flash_attn_ext_vec_impl(
|
|
|
6566
6751
|
}
|
|
6567
6752
|
|
|
6568
6753
|
#undef NWG
|
|
6754
|
+
#undef NSG
|
|
6569
6755
|
#undef NS10
|
|
6570
6756
|
#undef NS20
|
|
6571
6757
|
}
|
|
6572
6758
|
|
|
6573
|
-
template<
|
|
6574
|
-
typename q4_t, // query types in shared memory
|
|
6575
|
-
typename k4_t, // key types in shared memory
|
|
6576
|
-
typename v4_t, // value types in shared memory
|
|
6577
|
-
typename qk_t, // Q*K types
|
|
6578
|
-
typename s_t, // soft-max types
|
|
6579
|
-
typename s4_t,
|
|
6580
|
-
typename o4_t, // attention accumulation types
|
|
6581
|
-
typename kd4_t, // key type in device memory
|
|
6582
|
-
short nl_k,
|
|
6583
|
-
void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
|
|
6584
|
-
typename vd4_t, // value type in device memory
|
|
6585
|
-
short nl_v,
|
|
6586
|
-
void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
|
|
6587
|
-
short DK, // K head size
|
|
6588
|
-
short DV, // V head size
|
|
6589
|
-
short NE = 4, // head elements per thread
|
|
6590
|
-
short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup
|
|
6591
|
-
short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
|
|
6592
|
-
kernel void kernel_flash_attn_ext_vec(
|
|
6593
|
-
constant ggml_metal_kargs_flash_attn_ext_vec & args,
|
|
6594
|
-
device const char * q,
|
|
6595
|
-
device const char * k,
|
|
6596
|
-
device const char * v,
|
|
6597
|
-
device const char * mask,
|
|
6598
|
-
device const char * sinks,
|
|
6599
|
-
device const char * pad,
|
|
6600
|
-
device char * dst,
|
|
6601
|
-
threadgroup half * shmem_f16 [[threadgroup(0)]],
|
|
6602
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6603
|
-
ushort tiisg[[thread_index_in_simdgroup]],
|
|
6604
|
-
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
6605
|
-
#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
|
|
6606
|
-
#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
|
|
6607
|
-
switch (FC_flash_attn_ext_vec_nsg) {
|
|
6608
|
-
// note: disabled cases to reduce library load time
|
|
6609
|
-
case 1: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
|
6610
|
-
case 2: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 2>(FWD_ARGS); break;
|
|
6611
|
-
case 4: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 4>(FWD_ARGS); break;
|
|
6612
|
-
//case 8: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 8>(FWD_ARGS); break;
|
|
6613
|
-
//case 16: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 16>(FWD_ARGS); break;
|
|
6614
|
-
//case 32: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 32>(FWD_ARGS); break;
|
|
6615
|
-
}
|
|
6616
|
-
#undef FWD_TMPL
|
|
6617
|
-
#undef FWD_ARGS
|
|
6618
|
-
}
|
|
6619
|
-
|
|
6620
6759
|
// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
|
|
6621
6760
|
// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
|
|
6622
6761
|
//
|
|
@@ -6715,6 +6854,17 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flas
|
|
|
6715
6854
|
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
|
|
6716
6855
|
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
|
|
6717
6856
|
|
|
6857
|
+
template [[host_name("kernel_flash_attn_ext_vec_f32_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 320, 256, 2>;
|
|
6858
|
+
template [[host_name("kernel_flash_attn_ext_vec_f16_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 320, 256, 2>;
|
|
6859
|
+
#if defined(GGML_METAL_HAS_BF16)
|
|
6860
|
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 320, 256, 2>;
|
|
6861
|
+
#endif
|
|
6862
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 320, 256, 2>;
|
|
6863
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 320, 256, 2>;
|
|
6864
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 320, 256, 2>;
|
|
6865
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 320, 256, 2>;
|
|
6866
|
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 320, 256, 2>;
|
|
6867
|
+
|
|
6718
6868
|
template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
|
|
6719
6869
|
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
|
|
6720
6870
|
#if defined(GGML_METAL_HAS_BF16)
|
|
@@ -8779,6 +8929,26 @@ kernel void kernel_set_rows_f(
|
|
|
8779
8929
|
}
|
|
8780
8930
|
}
|
|
8781
8931
|
|
|
8932
|
+
kernel void kernel_diag_f32(
|
|
8933
|
+
constant ggml_metal_kargs_diag & args,
|
|
8934
|
+
device const char * src0,
|
|
8935
|
+
device char * dst,
|
|
8936
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
8937
|
+
ushort tiitg[[thread_index_in_threadgroup]]) {
|
|
8938
|
+
constexpr short NW = N_SIMDWIDTH;
|
|
8939
|
+
|
|
8940
|
+
const int32_t i3 = tgpig.z;
|
|
8941
|
+
const int32_t i2 = tgpig.y;
|
|
8942
|
+
const int32_t i1 = tgpig.x;
|
|
8943
|
+
|
|
8944
|
+
device const float * src0_ptr = (device const float *)(src0 + i2*args.nb02 + i3*args.nb03);
|
|
8945
|
+
device float * dst_ptr = (device float *)(dst + i1*args.nb01 + i2*args.nb2 + i3*args.nb3);
|
|
8946
|
+
|
|
8947
|
+
for (int i0 = tiitg; i0 < args.ne0; i0 += NW) {
|
|
8948
|
+
dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;
|
|
8949
|
+
}
|
|
8950
|
+
}
|
|
8951
|
+
|
|
8782
8952
|
constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
|
|
8783
8953
|
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
|
|
8784
8954
|
|
|
@@ -8797,7 +8967,9 @@ kernel void kernel_mul_mm(
|
|
|
8797
8967
|
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
|
8798
8968
|
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
|
8799
8969
|
|
|
8970
|
+
#ifdef GGML_METAL_HAS_TENSOR
|
|
8800
8971
|
threadgroup float * sc = (threadgroup float *)(shmem);
|
|
8972
|
+
#endif
|
|
8801
8973
|
|
|
8802
8974
|
constexpr int NR0 = 64;
|
|
8803
8975
|
constexpr int NR1 = 32;
|
|
@@ -8920,8 +9092,8 @@ kernel void kernel_mul_mm(
|
|
|
8920
9092
|
const short sx = (tiitg%NL1);
|
|
8921
9093
|
const short sy = (tiitg/NL1)/8;
|
|
8922
9094
|
|
|
8923
|
-
|
|
8924
|
-
|
|
9095
|
+
//const short dx = sx;
|
|
9096
|
+
//const short dy = sy;
|
|
8925
9097
|
|
|
8926
9098
|
const short ly = (tiitg/NL1)%8;
|
|
8927
9099
|
|
|
@@ -9153,6 +9325,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
|
|
|
9153
9325
|
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
|
|
9154
9326
|
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
|
|
9155
9327
|
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
|
|
9328
|
+
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
|
|
9156
9329
|
|
|
9157
9330
|
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
|
9158
9331
|
kernel void kernel_mul_mm_id(
|
|
@@ -9170,7 +9343,9 @@ kernel void kernel_mul_mm_id(
|
|
|
9170
9343
|
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
|
9171
9344
|
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
|
9172
9345
|
|
|
9346
|
+
#ifdef GGML_METAL_HAS_TENSOR
|
|
9173
9347
|
threadgroup float * sc = (threadgroup float *)(shmem);
|
|
9348
|
+
#endif
|
|
9174
9349
|
|
|
9175
9350
|
constexpr int NR0 = 64;
|
|
9176
9351
|
constexpr int NR1 = 32;
|
|
@@ -9305,8 +9480,8 @@ kernel void kernel_mul_mm_id(
|
|
|
9305
9480
|
const short sx = (tiitg%NL1);
|
|
9306
9481
|
const short sy = (tiitg/NL1)/8;
|
|
9307
9482
|
|
|
9308
|
-
|
|
9309
|
-
|
|
9483
|
+
//const short dx = sx;
|
|
9484
|
+
//const short dy = sy;
|
|
9310
9485
|
|
|
9311
9486
|
const short ly = (tiitg/NL1)%8;
|
|
9312
9487
|
|
|
@@ -9869,6 +10044,74 @@ kernel void kernel_pool_2d_avg_f32(
|
|
|
9869
10044
|
o_ptr[cur_oh * args.OW + cur_ow] = res;
|
|
9870
10045
|
}
|
|
9871
10046
|
|
|
10047
|
+
|
|
10048
|
+
kernel void kernel_pool_1d_max_f32(
|
|
10049
|
+
constant ggml_metal_kargs_pool_1d & args,
|
|
10050
|
+
device const float * src,
|
|
10051
|
+
device float * dst,
|
|
10052
|
+
uint gid [[thread_position_in_grid]]
|
|
10053
|
+
) {
|
|
10054
|
+
|
|
10055
|
+
if (gid >= args.np) {
|
|
10056
|
+
return;
|
|
10057
|
+
}
|
|
10058
|
+
|
|
10059
|
+
const int ow = (int)gid % args.OW;
|
|
10060
|
+
const int row = (int)gid / args.OW;
|
|
10061
|
+
|
|
10062
|
+
const int base = ow * args.s0 - args.p0;
|
|
10063
|
+
|
|
10064
|
+
float acc = -INFINITY;
|
|
10065
|
+
|
|
10066
|
+
const int src_off = row * args.IW;
|
|
10067
|
+
const int dst_off = row * args.OW;
|
|
10068
|
+
|
|
10069
|
+
for (int ki = 0; ki < args.k0; ++ki) {
|
|
10070
|
+
int j = base + ki;
|
|
10071
|
+
if (j < 0 || j >= args.IW){
|
|
10072
|
+
continue;
|
|
10073
|
+
}
|
|
10074
|
+
float v = src[src_off + j];
|
|
10075
|
+
acc = max(acc, v);
|
|
10076
|
+
}
|
|
10077
|
+
|
|
10078
|
+
dst[dst_off + ow] = acc;
|
|
10079
|
+
}
|
|
10080
|
+
|
|
10081
|
+
kernel void kernel_pool_1d_avg_f32(
|
|
10082
|
+
constant ggml_metal_kargs_pool_1d & args,
|
|
10083
|
+
device const float * src,
|
|
10084
|
+
device float * dst,
|
|
10085
|
+
uint gid [[thread_position_in_grid]]
|
|
10086
|
+
) {
|
|
10087
|
+
|
|
10088
|
+
if (gid >= args.np) {
|
|
10089
|
+
return;
|
|
10090
|
+
}
|
|
10091
|
+
|
|
10092
|
+
const int ow = (int)gid % args.OW;
|
|
10093
|
+
const int row = (int)gid / args.OW;
|
|
10094
|
+
|
|
10095
|
+
const int base = ow * args.s0 - args.p0;
|
|
10096
|
+
|
|
10097
|
+
float acc = 0.0f;
|
|
10098
|
+
int cnt = 0;
|
|
10099
|
+
|
|
10100
|
+
const int src_off = row * args.IW;
|
|
10101
|
+
const int dst_off = row * args.OW;
|
|
10102
|
+
|
|
10103
|
+
for (int ki = 0; ki < args.k0; ++ki) {
|
|
10104
|
+
const int j = base + ki;
|
|
10105
|
+
if (j < 0 || j >= args.IW) {
|
|
10106
|
+
continue;
|
|
10107
|
+
}
|
|
10108
|
+
acc += src[src_off + j];
|
|
10109
|
+
cnt += 1;
|
|
10110
|
+
}
|
|
10111
|
+
|
|
10112
|
+
dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f;
|
|
10113
|
+
}
|
|
10114
|
+
|
|
9872
10115
|
kernel void kernel_opt_step_adamw_f32(
|
|
9873
10116
|
constant ggml_metal_kargs_opt_step_adamw & args,
|
|
9874
10117
|
device float * x,
|
|
@@ -9919,7 +10162,7 @@ kernel void kernel_opt_step_sgd_f32(
|
|
|
9919
10162
|
|
|
9920
10163
|
template<typename T>
|
|
9921
10164
|
kernel void kernel_memset(
|
|
9922
|
-
constant
|
|
10165
|
+
constant ggml_metal_kargs_memset & args,
|
|
9923
10166
|
device T * dst,
|
|
9924
10167
|
uint tpig[[thread_position_in_grid]]) {
|
|
9925
10168
|
dst[tpig] = args.val;
|