whispercpp 1.3.4 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +158 -44
- data/ext/extconf.rb +3 -2
- data/ext/ruby_whisper.c +34 -6
- data/ext/ruby_whisper.h +67 -0
- data/ext/ruby_whisper_context.c +236 -144
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +12 -13
- data/ext/ruby_whisper_params.c +47 -24
- data/ext/ruby_whisper_segment.c +84 -20
- data/ext/ruby_whisper_token.c +371 -0
- data/ext/ruby_whisper_transcribe.cpp +5 -2
- data/ext/ruby_whisper_vad_context.c +122 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +138 -0
- data/ext/ruby_whisper_vad_segments.c +105 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +129 -112
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +28 -15
- data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
- data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
- data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
- data/ext/sources/examples/talk-llama/llama-context.h +70 -23
- data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
- data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
- data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
- data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
- data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
- data/ext/sources/examples/talk-llama/llama-model.h +112 -18
- data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
- data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
- data/ext/sources/examples/talk-llama/llama.cpp +802 -21
- data/ext/sources/examples/talk-llama/llama.h +210 -39
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +704 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +90 -56
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +5 -2
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +6 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +14 -12
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +246 -21
- data/ext/sources/ggml/src/CMakeLists.txt +85 -11
- data/ext/sources/ggml/src/ggml-alloc.c +128 -50
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
- data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
- data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
- data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
- data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
- data/ext/sources/ggml/src/ggml-impl.h +129 -6
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
- data/ext/sources/ggml/src/ggml.c +590 -64
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +106 -62
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +162 -4
- data/test/test_context_params.rb +82 -0
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +81 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +100 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +27 -0
- data/whispercpp.gemspec +1 -1
- metadata +502 -37
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -18,6 +18,10 @@
|
|
|
18
18
|
|
|
19
19
|
#include "common.cuh"
|
|
20
20
|
|
|
21
|
+
// On Volta each warp is doing 4 8x8 mma operations in parallel.
|
|
22
|
+
// The basic memory layout for a 32x8 output tile is to stack 4 input tiles in I direction and to mirror the B tile.
|
|
23
|
+
// However, the i indices in this file are by default permuted to simplify the index calculations.
|
|
24
|
+
// #define GGML_CUDA_MMA_NO_VOLTA_PERM
|
|
21
25
|
|
|
22
26
|
#if CUDART_VERSION >= 11080
|
|
23
27
|
|
|
@@ -64,15 +68,59 @@ static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
|
|
|
64
68
|
|
|
65
69
|
namespace ggml_cuda_mma {
|
|
66
70
|
|
|
71
|
+
// Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel,
|
|
72
|
+
// effectively the warp is being split into subgroups of threads that each perform a single mma instruction.
|
|
73
|
+
// In those cases the data can be split in different ways across the warp.
|
|
74
|
+
enum data_layout {
|
|
75
|
+
// By default the data uses the I direction as its major dimension and the J direction as its minor dimension.
|
|
76
|
+
// For the A/C matrices this means I major == row major, J major == column major.
|
|
77
|
+
// For the B matrix this means I major == column major, J major == row major.
|
|
78
|
+
// MIRRORED == Each data value is held exactly once per thread subgroup.
|
|
79
|
+
DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
|
|
80
|
+
DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
|
|
81
|
+
DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3.
|
|
82
|
+
DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
|
|
83
|
+
};
|
|
84
|
+
// Implemented mma combinations are:
|
|
85
|
+
// - (I_MAJOR, I_MAJOR) -> I_MAJOR
|
|
86
|
+
// - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
|
|
87
|
+
// - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
|
|
88
|
+
|
|
89
|
+
static constexpr bool is_i_major(const data_layout dl) {
|
|
90
|
+
return dl == DATA_LAYOUT_I_MAJOR ||
|
|
91
|
+
dl == DATA_LAYOUT_I_MAJOR_MIRRORED;
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
// Compile-time selection of the data layout used for mma inputs:
// mirrored I-major when building for RDNA3 or Volta, plain I-major for every other target.
static constexpr __device__ data_layout get_input_data_layout() {
#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_MIRRORED;
#else
    constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR;
#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    return input_layout;
}
|
|
101
|
+
|
|
102
|
+
// Primary tile template, intentionally left empty.
// The data layout defaults to DATA_LAYOUT_I_MAJOR; usable tiles come from the partial specializations.
template <int I_, int J_, typename T, data_layout ds_ = DATA_LAYOUT_I_MAJOR>
struct tile {
};
|
|
104
|
+
|
|
67
105
|
template <int I_, int J_, typename T>
|
|
68
|
-
struct tile {
|
|
69
|
-
static constexpr int
|
|
70
|
-
static constexpr int
|
|
106
|
+
struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR> {
|
|
107
|
+
static constexpr int I = I_;
|
|
108
|
+
static constexpr int J = J_;
|
|
109
|
+
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
|
|
71
110
|
|
|
72
|
-
#if defined(
|
|
111
|
+
#if defined(AMD_MFMA_AVAILABLE)
|
|
73
112
|
static constexpr int ne = I * J / 64;
|
|
74
113
|
T x[ne] = {0};
|
|
75
114
|
|
|
115
|
+
static constexpr __device__ bool supported() {
|
|
116
|
+
if (I == 64 && J == 2) return true;
|
|
117
|
+
if (I == 16 && J == 8) return true;
|
|
118
|
+
if (I == 32 && J == 4) return true;
|
|
119
|
+
if (I == 16 && J == 16) return true;
|
|
120
|
+
if (I == 32 && J == 32) return true;
|
|
121
|
+
return false;
|
|
122
|
+
}
|
|
123
|
+
|
|
76
124
|
static __device__ __forceinline__ int get_i(const int l) {
|
|
77
125
|
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
|
|
78
126
|
return threadIdx.x % 16;
|
|
@@ -81,11 +129,12 @@ namespace ggml_cuda_mma {
|
|
|
81
129
|
} else if constexpr (I == 32 && J == 4) {
|
|
82
130
|
return threadIdx.x % 32;
|
|
83
131
|
} else if constexpr (I == 16 && J == 16) {
|
|
84
|
-
return
|
|
132
|
+
return threadIdx.x % 16;
|
|
85
133
|
} else if constexpr (I == 32 && J == 32) {
|
|
86
|
-
return
|
|
134
|
+
return threadIdx.x % 32;
|
|
87
135
|
} else {
|
|
88
|
-
|
|
136
|
+
NO_DEVICE_CODE;
|
|
137
|
+
return -1;
|
|
89
138
|
}
|
|
90
139
|
}
|
|
91
140
|
|
|
@@ -97,26 +146,115 @@ namespace ggml_cuda_mma {
|
|
|
97
146
|
} else if constexpr (I == 32 && J == 4) {
|
|
98
147
|
return 2 * (threadIdx.x / 32) + l;
|
|
99
148
|
} else if constexpr (I == 16 && J == 16) {
|
|
100
|
-
return threadIdx.x
|
|
149
|
+
return 4 * (threadIdx.x / 16) + l;
|
|
101
150
|
} else if constexpr (I == 32 && J == 32) {
|
|
102
|
-
return threadIdx.x %
|
|
151
|
+
return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
|
|
103
152
|
} else {
|
|
104
|
-
|
|
153
|
+
NO_DEVICE_CODE;
|
|
154
|
+
return -1;
|
|
105
155
|
}
|
|
106
156
|
}
|
|
157
|
+
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
158
|
+
static constexpr int ne = I * J / 32;
|
|
159
|
+
T x[ne] = {0};
|
|
160
|
+
|
|
161
|
+
static constexpr __device__ bool supported() {
|
|
162
|
+
if (I == 32 && J == 8) return true;
|
|
163
|
+
return false;
|
|
164
|
+
}
|
|
165
|
+
|
|
166
|
+
static __device__ __forceinline__ int get_i(const int l) {
|
|
167
|
+
if constexpr (I == 32 && J == 8) {
|
|
168
|
+
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
|
|
169
|
+
return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2);
|
|
107
170
|
#else
|
|
171
|
+
return (l & 2) + (threadIdx.x & ~2);
|
|
172
|
+
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
|
|
173
|
+
} else {
|
|
174
|
+
NO_DEVICE_CODE;
|
|
175
|
+
return -1;
|
|
176
|
+
}
|
|
177
|
+
}
|
|
178
|
+
|
|
179
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
180
|
+
if constexpr (I == 32 && J == 8) {
|
|
181
|
+
return (threadIdx.x & 2) + (l & (4 + 1));
|
|
182
|
+
} else {
|
|
183
|
+
NO_DEVICE_CODE;
|
|
184
|
+
return -1;
|
|
185
|
+
}
|
|
186
|
+
}
|
|
187
|
+
#elif defined(AMD_WMMA_AVAILABLE)
|
|
108
188
|
static constexpr int ne = I * J / 32;
|
|
109
189
|
T x[ne] = {0};
|
|
110
190
|
|
|
191
|
+
static constexpr __device__ bool supported() {
|
|
192
|
+
if (I == 16 && J == 16) return true;
|
|
193
|
+
if (I == 16 && J == 8) return true;
|
|
194
|
+
if (I == 16 && J == 4) return true;
|
|
195
|
+
return false;
|
|
196
|
+
}
|
|
197
|
+
|
|
111
198
|
static __device__ __forceinline__ int get_i(const int l) {
|
|
112
|
-
if constexpr (
|
|
199
|
+
if constexpr (supported()) {
|
|
200
|
+
return threadIdx.x % 16;
|
|
201
|
+
} else {
|
|
202
|
+
NO_DEVICE_CODE;
|
|
203
|
+
return -1;
|
|
204
|
+
}
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
208
|
+
if constexpr (I == 16 && J == 16) {
|
|
209
|
+
#if defined(RDNA3)
|
|
210
|
+
if constexpr (std::is_same_v<T, float> || std::is_same_v<T, int>) {
|
|
211
|
+
// matrix C
|
|
212
|
+
return 2 * l + (threadIdx.x / 16);
|
|
213
|
+
} else {
|
|
214
|
+
// matrix A&B
|
|
215
|
+
return l;
|
|
216
|
+
}
|
|
217
|
+
#else
|
|
218
|
+
// matrix C is the transposed matrix A&B on RDNA4
|
|
219
|
+
return ne * (threadIdx.x / 16) + l;
|
|
220
|
+
#endif // defined(RDNA3)
|
|
221
|
+
} else if constexpr (I == 16 && J == 8) {
|
|
222
|
+
// mmq input for RDNA4
|
|
223
|
+
return ne * (threadIdx.x / 16) + l;
|
|
224
|
+
} else if constexpr (I == 16 && J == 4) {
|
|
225
|
+
return ne * (threadIdx.x / 16) + l;
|
|
226
|
+
} else {
|
|
227
|
+
NO_DEVICE_CODE;
|
|
228
|
+
return -1;
|
|
229
|
+
}
|
|
230
|
+
}
|
|
231
|
+
#else
|
|
232
|
+
static constexpr int ne = I * J / 32;
|
|
233
|
+
T x[ne] = {0};
|
|
234
|
+
|
|
235
|
+
static constexpr __device__ bool supported() {
|
|
236
|
+
if (I == 8 && J == 4) return true;
|
|
237
|
+
if (I == 8 && J == 8) return true;
|
|
238
|
+
if (I == 16 && J == 8) return true;
|
|
239
|
+
if (I == 16 && J == 16) return true;
|
|
240
|
+
if (I == 32 && J == 8) return true;
|
|
241
|
+
return false;
|
|
242
|
+
}
|
|
243
|
+
|
|
244
|
+
static __device__ __forceinline__ int get_i(const int l) {
|
|
245
|
+
if constexpr (I == 8 && J == 4) {
|
|
246
|
+
return threadIdx.x / 4;
|
|
247
|
+
} else if constexpr (I == 8 && J == 8) {
|
|
113
248
|
return threadIdx.x / 4;
|
|
114
249
|
} else if constexpr (I == 16 && J == 8) {
|
|
115
|
-
return (l / 2) * 8 + threadIdx.x / 4;
|
|
250
|
+
return ((l / 2) * 8) + (threadIdx.x / 4);
|
|
116
251
|
} else if constexpr (I == 16 && J == 16) {
|
|
117
|
-
return ((l / 2) % 2) * 8 + threadIdx.x / 4;
|
|
252
|
+
return (((l / 2) % 2) * 8) + (threadIdx.x / 4);
|
|
253
|
+
} else if constexpr (I == 32 && J == 8) {
|
|
254
|
+
return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
|
|
118
255
|
} else {
|
|
119
|
-
|
|
256
|
+
NO_DEVICE_CODE;
|
|
257
|
+
return -1;
|
|
120
258
|
}
|
|
121
259
|
}
|
|
122
260
|
|
|
@@ -124,82 +262,395 @@ namespace ggml_cuda_mma {
|
|
|
124
262
|
if constexpr (I == 8 && J == 4) {
|
|
125
263
|
return threadIdx.x % 4;
|
|
126
264
|
} else if constexpr (I == 8 && J == 8) {
|
|
127
|
-
return
|
|
265
|
+
return (l * 4) + (threadIdx.x % 4);
|
|
128
266
|
} else if constexpr (I == 16 && J == 8) {
|
|
129
|
-
return
|
|
267
|
+
return ((threadIdx.x % 4) * 2) + (l % 2);
|
|
130
268
|
} else if constexpr (I == 16 && J == 16) {
|
|
131
|
-
return
|
|
269
|
+
return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2);
|
|
270
|
+
} else if constexpr (I == 32 && J == 8) {
|
|
271
|
+
return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
|
|
132
272
|
} else {
|
|
133
|
-
|
|
273
|
+
NO_DEVICE_CODE;
|
|
274
|
+
return -1;
|
|
134
275
|
}
|
|
135
276
|
}
|
|
136
277
|
#endif // defined(AMD_MFMA_AVAILABLE)
|
|
137
278
|
};
|
|
138
279
|
|
|
139
280
|
template <int I_, int J_>
|
|
140
|
-
struct tile<I_, J_, half2> {
|
|
141
|
-
static constexpr int
|
|
142
|
-
static constexpr int
|
|
281
|
+
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR> {
|
|
282
|
+
static constexpr int I = I_;
|
|
283
|
+
static constexpr int J = J_;
|
|
284
|
+
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
|
|
285
|
+
|
|
286
|
+
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
143
287
|
static constexpr int ne = I * J / WARP_SIZE;
|
|
144
288
|
half2 x[ne] = {{0.0f, 0.0f}};
|
|
145
289
|
|
|
290
|
+
static constexpr __device__ bool supported() {
|
|
291
|
+
if (I == 32 && J == 4) return true;
|
|
292
|
+
return false;
|
|
293
|
+
}
|
|
294
|
+
|
|
295
|
+
static __device__ __forceinline__ int get_i(const int l) {
|
|
296
|
+
if constexpr (I == 32 && J == 4) {
|
|
297
|
+
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
|
|
298
|
+
return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
|
|
299
|
+
#else
|
|
300
|
+
return threadIdx.x;
|
|
301
|
+
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
|
|
302
|
+
} else {
|
|
303
|
+
NO_DEVICE_CODE;
|
|
304
|
+
return -1;
|
|
305
|
+
}
|
|
306
|
+
}
|
|
307
|
+
|
|
308
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
309
|
+
if constexpr (I == 32 && J == 4) {
|
|
310
|
+
return l;
|
|
311
|
+
} else {
|
|
312
|
+
NO_DEVICE_CODE;
|
|
313
|
+
return -1;
|
|
314
|
+
}
|
|
315
|
+
}
|
|
316
|
+
#elif defined(AMD_WMMA_AVAILABLE)
|
|
317
|
+
static constexpr int ne = I * J / 32;
|
|
318
|
+
half2 x[ne] = {{0.0f, 0.0f}};
|
|
319
|
+
|
|
320
|
+
static constexpr __device__ bool supported() {
|
|
321
|
+
if (I == 16 && J == 8) return true;
|
|
322
|
+
return false;
|
|
323
|
+
}
|
|
324
|
+
|
|
325
|
+
static __device__ __forceinline__ int get_i(const int l) {
|
|
326
|
+
if constexpr (I == 16 && J == 8) {
|
|
327
|
+
return threadIdx.x % 16;
|
|
328
|
+
} else {
|
|
329
|
+
NO_DEVICE_CODE;
|
|
330
|
+
return -1;
|
|
331
|
+
}
|
|
332
|
+
}
|
|
333
|
+
|
|
334
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
335
|
+
if constexpr (I == 16 && J == 8) {
|
|
336
|
+
return ne * (threadIdx.x / 16) + l;
|
|
337
|
+
} else {
|
|
338
|
+
NO_DEVICE_CODE;
|
|
339
|
+
return -1;
|
|
340
|
+
}
|
|
341
|
+
}
|
|
342
|
+
#elif defined(AMD_MFMA_AVAILABLE)
|
|
343
|
+
static constexpr int ne = I * J / 64;
|
|
344
|
+
half2 x[ne] = {{0.0f, 0.0f}};
|
|
345
|
+
|
|
346
|
+
static constexpr __device__ bool supported() {
|
|
347
|
+
if (I == 16 && J == 8) return true;
|
|
348
|
+
return false;
|
|
349
|
+
}
|
|
350
|
+
|
|
351
|
+
static __device__ __forceinline__ int get_i(const int l) {
|
|
352
|
+
if constexpr (I == 16 && J == 8) {
|
|
353
|
+
return threadIdx.x % 16;
|
|
354
|
+
} else {
|
|
355
|
+
NO_DEVICE_CODE;
|
|
356
|
+
return -1;
|
|
357
|
+
}
|
|
358
|
+
}
|
|
359
|
+
|
|
360
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
361
|
+
if constexpr (I == 16 && J == 8) {
|
|
362
|
+
return ne * (threadIdx.x / 16) + l;
|
|
363
|
+
} else {
|
|
364
|
+
NO_DEVICE_CODE;
|
|
365
|
+
return -1;
|
|
366
|
+
}
|
|
367
|
+
}
|
|
368
|
+
#else
|
|
369
|
+
static constexpr int ne = I * J / WARP_SIZE;
|
|
370
|
+
half2 x[ne] = {{0.0f, 0.0f}};
|
|
371
|
+
|
|
372
|
+
static constexpr __device__ bool supported() {
|
|
373
|
+
if (I == 8 && J == 4) return true;
|
|
374
|
+
if (I == 8 && J == 8) return true;
|
|
375
|
+
if (I == 16 && J == 8) return true;
|
|
376
|
+
if (I == 16 && J == 16) return true;
|
|
377
|
+
if (I == 32 && J == 8) return true;
|
|
378
|
+
return false;
|
|
379
|
+
}
|
|
380
|
+
|
|
146
381
|
static __device__ __forceinline__ int get_i(const int l) {
|
|
147
382
|
if constexpr (I == 8 && J == 8) {
|
|
148
383
|
return threadIdx.x / 4;
|
|
149
384
|
} else if constexpr (I == 16 && J == 4) {
|
|
150
|
-
return l * 8 + threadIdx.x / 4;
|
|
385
|
+
return (l * 8) + (threadIdx.x / 4);
|
|
151
386
|
} else if constexpr (I == 16 && J == 8) {
|
|
152
|
-
return (l % 2) * 8 + threadIdx.x / 4;
|
|
387
|
+
return ((l % 2) * 8) + (threadIdx.x / 4);
|
|
388
|
+
} else if constexpr (I == 32 && J == 8) {
|
|
389
|
+
return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4);
|
|
153
390
|
} else {
|
|
154
|
-
|
|
391
|
+
NO_DEVICE_CODE;
|
|
392
|
+
return -1;
|
|
155
393
|
}
|
|
156
394
|
}
|
|
157
395
|
|
|
158
396
|
static __device__ __forceinline__ int get_j(const int l) {
|
|
159
397
|
if constexpr (I == 8 && J == 8) {
|
|
160
|
-
return l * 4 + threadIdx.x % 4;
|
|
398
|
+
return (l * 4) + (threadIdx.x % 4);
|
|
161
399
|
} else if constexpr (I == 16 && J == 4) {
|
|
162
400
|
return threadIdx.x % 4;
|
|
163
401
|
} else if constexpr (I == 16 && J == 8) {
|
|
164
|
-
return (l / 2) * 4 + threadIdx.x % 4;
|
|
402
|
+
return ((l / 2) * 4) + (threadIdx.x % 4);
|
|
403
|
+
} else if constexpr (I == 32 && J == 8) {
|
|
404
|
+
return ((l & 2) * 2) + (threadIdx.x % 4);
|
|
165
405
|
} else {
|
|
166
|
-
|
|
406
|
+
NO_DEVICE_CODE;
|
|
407
|
+
return -1;
|
|
167
408
|
}
|
|
168
409
|
}
|
|
410
|
+
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
169
411
|
};
|
|
170
412
|
|
|
171
413
|
template <int I_, int J_>
|
|
172
|
-
struct tile<I_, J_, nv_bfloat162> {
|
|
173
|
-
static constexpr int
|
|
174
|
-
static constexpr int
|
|
414
|
+
struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR> {
|
|
415
|
+
static constexpr int I = I_;
|
|
416
|
+
static constexpr int J = J_;
|
|
417
|
+
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
|
|
418
|
+
|
|
419
|
+
#if defined(AMD_WMMA_AVAILABLE)
|
|
420
|
+
static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
|
|
421
|
+
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
|
422
|
+
|
|
423
|
+
static constexpr __device__ bool supported() {
|
|
424
|
+
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
|
|
425
|
+
}
|
|
426
|
+
|
|
427
|
+
static __device__ __forceinline__ int get_i(const int l) {
|
|
428
|
+
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
|
|
429
|
+
}
|
|
430
|
+
|
|
431
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
432
|
+
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
|
|
433
|
+
}
|
|
434
|
+
#elif defined(AMD_MFMA_AVAILABLE)
|
|
435
|
+
static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
|
|
436
|
+
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
|
437
|
+
|
|
438
|
+
static constexpr __device__ bool supported() {
|
|
439
|
+
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
|
|
440
|
+
}
|
|
441
|
+
|
|
442
|
+
static __device__ __forceinline__ int get_i(const int l) {
|
|
443
|
+
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
|
|
444
|
+
}
|
|
445
|
+
|
|
446
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
447
|
+
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
|
|
448
|
+
}
|
|
449
|
+
#else
|
|
175
450
|
static constexpr int ne = I * J / WARP_SIZE;
|
|
176
451
|
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
|
177
452
|
|
|
453
|
+
static constexpr __device__ bool supported() {
|
|
454
|
+
if (I == 8 && J == 8) return true;
|
|
455
|
+
if (I == 16 && J == 4) return true;
|
|
456
|
+
if (I == 16 && J == 8) return true;
|
|
457
|
+
return false;
|
|
458
|
+
}
|
|
459
|
+
|
|
178
460
|
static __device__ __forceinline__ int get_i(const int l) {
|
|
179
461
|
if constexpr (I == 8 && J == 8) {
|
|
180
462
|
return threadIdx.x / 4;
|
|
181
463
|
} else if constexpr (I == 16 && J == 4) {
|
|
182
|
-
return l * 8 + threadIdx.x / 4;
|
|
464
|
+
return (l * 8) + (threadIdx.x / 4);
|
|
183
465
|
} else if constexpr (I == 16 && J == 8) {
|
|
184
|
-
return (l % 2) * 8 + threadIdx.x / 4;
|
|
466
|
+
return ((l % 2) * 8) + (threadIdx.x / 4);
|
|
185
467
|
} else {
|
|
186
|
-
|
|
468
|
+
NO_DEVICE_CODE;
|
|
469
|
+
return -1;
|
|
187
470
|
}
|
|
188
471
|
}
|
|
189
472
|
|
|
190
473
|
static __device__ __forceinline__ int get_j(const int l) {
|
|
191
474
|
if constexpr (I == 8 && J == 8) {
|
|
192
|
-
return l * 4 + threadIdx.x % 4;
|
|
475
|
+
return (l * 4) + (threadIdx.x % 4);
|
|
193
476
|
} else if constexpr (I == 16 && J == 4) {
|
|
194
477
|
return threadIdx.x % 4;
|
|
195
478
|
} else if constexpr (I == 16 && J == 8) {
|
|
196
|
-
return (l / 2) * 4 + threadIdx.x % 4;
|
|
479
|
+
return ((l / 2) * 4) + (threadIdx.x % 4);
|
|
480
|
+
} else {
|
|
481
|
+
NO_DEVICE_CODE;
|
|
482
|
+
return -1;
|
|
483
|
+
}
|
|
484
|
+
}
|
|
485
|
+
#endif // defined(AMD_WMMA_AVAILABLE)
|
|
486
|
+
};
|
|
487
|
+
|
|
488
|
+
template <int I_, int J_, typename T>
|
|
489
|
+
struct tile<I_, J_, T, DATA_LAYOUT_J_MAJOR> {
|
|
490
|
+
static constexpr int I = I_;
|
|
491
|
+
static constexpr int J = J_;
|
|
492
|
+
static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR;
|
|
493
|
+
|
|
494
|
+
static constexpr int ne = tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::ne;
|
|
495
|
+
T x[ne] = {0};
|
|
496
|
+
|
|
497
|
+
static constexpr __device__ bool supported() {
|
|
498
|
+
return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::supported();
|
|
499
|
+
}
|
|
500
|
+
|
|
501
|
+
static __device__ __forceinline__ int get_i(const int l) {
|
|
502
|
+
return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_j(l);
|
|
503
|
+
}
|
|
504
|
+
|
|
505
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
506
|
+
return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_i(l);
|
|
507
|
+
}
|
|
508
|
+
};
|
|
509
|
+
|
|
510
|
+
template <int I_, int J_, typename T>
|
|
511
|
+
struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_MIRRORED> {
|
|
512
|
+
static constexpr int I = I_;
|
|
513
|
+
static constexpr int J = J_;
|
|
514
|
+
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
|
|
515
|
+
|
|
516
|
+
// RDNA3
|
|
517
|
+
static constexpr int ne = I * J / 32 * 2;
|
|
518
|
+
|
|
519
|
+
T x[ne] = {0};
|
|
520
|
+
|
|
521
|
+
static constexpr __device__ bool supported() {
|
|
522
|
+
if (I == 16 && J == 16) return true;
|
|
523
|
+
if (I == 16 && J == 8) return true;
|
|
524
|
+
if (I == 16 && J == 4) return true;
|
|
525
|
+
return false;
|
|
526
|
+
}
|
|
527
|
+
|
|
528
|
+
static __device__ __forceinline__ int get_i(const int /*l*/) {
|
|
529
|
+
if constexpr (supported()) {
|
|
530
|
+
return threadIdx.x % 16;
|
|
197
531
|
} else {
|
|
198
|
-
|
|
532
|
+
NO_DEVICE_CODE;
|
|
533
|
+
return -1;
|
|
534
|
+
}
|
|
535
|
+
}
|
|
536
|
+
|
|
537
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
538
|
+
if constexpr (supported()) {
|
|
539
|
+
return l;
|
|
540
|
+
} else {
|
|
541
|
+
NO_DEVICE_CODE;
|
|
542
|
+
return -1;
|
|
199
543
|
}
|
|
200
544
|
}
|
|
201
545
|
};
|
|
202
546
|
|
|
547
|
+
template <int I_, int J_>
|
|
548
|
+
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
|
|
549
|
+
static constexpr int I = I_;
|
|
550
|
+
static constexpr int J = J_;
|
|
551
|
+
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
|
|
552
|
+
#if defined(RDNA3)
|
|
553
|
+
static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
|
|
554
|
+
|
|
555
|
+
half2 x[ne] = {{0.0f, 0.0f}};
|
|
556
|
+
|
|
557
|
+
static constexpr __device__ bool supported() {
|
|
558
|
+
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
|
|
559
|
+
}
|
|
560
|
+
|
|
561
|
+
static __device__ __forceinline__ int get_i(const int l) {
|
|
562
|
+
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
|
|
563
|
+
}
|
|
564
|
+
|
|
565
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
566
|
+
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
|
|
567
|
+
}
|
|
568
|
+
#else // Volta
|
|
569
|
+
static constexpr int ne = I * J / (WARP_SIZE/4);
|
|
570
|
+
|
|
571
|
+
half2 x[ne] = {{0.0f, 0.0f}};
|
|
572
|
+
|
|
573
|
+
static constexpr __device__ bool supported() {
|
|
574
|
+
if (I == 8 && J == 4) return true;
|
|
575
|
+
return false;
|
|
576
|
+
}
|
|
577
|
+
|
|
578
|
+
static __device__ __forceinline__ int get_i(const int /*l*/) {
|
|
579
|
+
if constexpr (I == 8 && J == 4) {
|
|
580
|
+
return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
|
|
581
|
+
} else {
|
|
582
|
+
NO_DEVICE_CODE;
|
|
583
|
+
return -1;
|
|
584
|
+
}
|
|
585
|
+
}
|
|
586
|
+
|
|
587
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
588
|
+
if constexpr (I == 8 && J == 4) {
|
|
589
|
+
return l;
|
|
590
|
+
} else {
|
|
591
|
+
NO_DEVICE_CODE;
|
|
592
|
+
return -1;
|
|
593
|
+
}
|
|
594
|
+
}
|
|
595
|
+
#endif // defined(RDNA3)
|
|
596
|
+
};
|
|
597
|
+
|
|
598
|
+
template <int I_, int J_>
|
|
599
|
+
struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR_MIRRORED> {
|
|
600
|
+
static constexpr int I = I_;
|
|
601
|
+
static constexpr int J = J_;
|
|
602
|
+
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
|
|
603
|
+
static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
|
|
604
|
+
|
|
605
|
+
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
|
606
|
+
|
|
607
|
+
static constexpr __device__ bool supported() {
|
|
608
|
+
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
|
|
609
|
+
}
|
|
610
|
+
|
|
611
|
+
static __device__ __forceinline__ int get_i(const int l) {
|
|
612
|
+
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
|
|
613
|
+
}
|
|
614
|
+
|
|
615
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
616
|
+
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
|
|
617
|
+
}
|
|
618
|
+
};
|
|
619
|
+
|
|
620
|
+
template <int I_, int J_>
|
|
621
|
+
struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
|
|
622
|
+
static constexpr int I = I_;
|
|
623
|
+
static constexpr int J = J_;
|
|
624
|
+
static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
|
|
625
|
+
static constexpr int ne = I * J / (WARP_SIZE/4);
|
|
626
|
+
|
|
627
|
+
half2 x[ne] = {{0.0f, 0.0f}};
|
|
628
|
+
|
|
629
|
+
static constexpr __device__ bool supported() {
|
|
630
|
+
if (I == 8 && J == 4) return true;
|
|
631
|
+
return false;
|
|
632
|
+
}
|
|
633
|
+
|
|
634
|
+
static __device__ __forceinline__ int get_i(const int l) {
|
|
635
|
+
if constexpr (I == 8 && J == 4) {
|
|
636
|
+
return ((l / 2) * 4) + (threadIdx.x % 4);
|
|
637
|
+
} else {
|
|
638
|
+
NO_DEVICE_CODE;
|
|
639
|
+
return -1;
|
|
640
|
+
}
|
|
641
|
+
}
|
|
642
|
+
|
|
643
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
644
|
+
if constexpr (I == 8 && J == 4) {
|
|
645
|
+
return ((threadIdx.x / 16) * 2) + (l % 2);
|
|
646
|
+
} else {
|
|
647
|
+
NO_DEVICE_CODE;
|
|
648
|
+
return -1;
|
|
649
|
+
}
|
|
650
|
+
}
|
|
651
|
+
};
|
|
652
|
+
|
|
653
|
+
#if defined(TURING_MMA_AVAILABLE)
|
|
203
654
|
template <int I, int J>
|
|
204
655
|
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
|
|
205
656
|
tile<I, J/2, half2> ret;
|
|
@@ -217,9 +668,54 @@ namespace ggml_cuda_mma {
|
|
|
217
668
|
|
|
218
669
|
return ret;
|
|
219
670
|
}
|
|
671
|
+
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
672
|
+
template <int I, int J>
|
|
673
|
+
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
|
|
674
|
+
tile<I, J/2, half2> ret;
|
|
675
|
+
#pragma unroll
|
|
676
|
+
for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
|
|
677
|
+
ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
|
|
678
|
+
}
|
|
679
|
+
return ret;
|
|
680
|
+
}
|
|
681
|
+
|
|
682
|
+
// Stub for the AMD WMMA/MFMA build: transposing a <16, 4> half2 tile is not implemented here.
// Reaching this function executes NO_DEVICE_CODE; the returned tile is value-initialized.
static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
    GGML_UNUSED_VARS(t); // Mark the argument as unused, like the other NO_DEVICE_CODE stubs in this file.
    NO_DEVICE_CODE;
    return tile<8, 8, half2>{};
}
|
|
686
|
+
#else // Volta
|
|
687
|
+
template <int I, int J>
|
|
688
|
+
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
|
|
689
|
+
tile<I, J/2, half2> ret;
|
|
690
|
+
#pragma unroll
|
|
691
|
+
for (int l0 = 0; l0 < tile_float.ne; l0 += 4) {
|
|
692
|
+
ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
|
|
693
|
+
ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]);
|
|
694
|
+
|
|
695
|
+
// On Volta FP16 and FP32 tiles have a different memory layout,
|
|
696
|
+
// for the conversion threads with an offset of 2 need to exchange half their values:
|
|
697
|
+
ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync(
|
|
698
|
+
0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE);
|
|
699
|
+
}
|
|
700
|
+
return ret;
|
|
701
|
+
}
|
|
702
|
+
#endif // defined(TURING_MMA_AVAILABLE)
|
|
220
703
|
|
|
221
|
-
|
|
222
|
-
|
|
704
|
+
// RDNA4 only: writes a single half value of t (viewed as an array of half) so that the tile
// takes part in an identity matrix. The remaining elements of t.x are left untouched.
// On every other target this is a NO_DEVICE_CODE stub.
static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) {
#if defined(RDNA4)
    const int i0       = t.get_i(0);
    const int col_half = t.get_j(0) / 4;
    const int row_half = i0 / 8;

    half * const elems = reinterpret_cast<half*>(t.x);
    // 1 when the thread's column half matches its row half, 0 otherwise.
    elems[i0 % 8] = col_half == row_half ? 1.0f : 0.0f;
#else
    GGML_UNUSED_VARS(t);
    NO_DEVICE_CODE;
#endif // defined(RDNA4)
}
|
|
716
|
+
|
|
717
|
+
template <int I, int J, typename T, data_layout dl>
|
|
718
|
+
static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
|
|
223
719
|
#if defined(AMD_MFMA_AVAILABLE)
|
|
224
720
|
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
|
|
225
721
|
#pragma unroll
|
|
@@ -227,9 +723,28 @@ namespace ggml_cuda_mma {
|
|
|
227
723
|
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
|
|
228
724
|
}
|
|
229
725
|
} else {
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
726
|
+
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
|
|
727
|
+
}
|
|
728
|
+
#elif defined(AMD_WMMA_AVAILABLE)
|
|
729
|
+
// All WMMA layouts have contiguous data when I-major.
|
|
730
|
+
if constexpr (is_i_major(dl)) {
|
|
731
|
+
// the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes()
|
|
732
|
+
constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();
|
|
733
|
+
if constexpr (sizeof(t.x) > aligned_copy_bytes) {
|
|
734
|
+
static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size");
|
|
735
|
+
constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;
|
|
736
|
+
#pragma unroll
|
|
737
|
+
for (int i = 0; i < aligned_copy_count; ++i) {
|
|
738
|
+
ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));
|
|
739
|
+
}
|
|
740
|
+
} else {
|
|
741
|
+
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
|
|
742
|
+
}
|
|
743
|
+
} else {
|
|
744
|
+
#pragma unroll
|
|
745
|
+
for (int l = 0; l < t.ne; ++l) {
|
|
746
|
+
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
|
|
747
|
+
}
|
|
233
748
|
}
|
|
234
749
|
#else
|
|
235
750
|
#pragma unroll
|
|
@@ -263,25 +778,63 @@ namespace ggml_cuda_mma {
|
|
|
263
778
|
: "=r"(xi[0]), "=r"(xi[1])
|
|
264
779
|
: "l"(xs));
|
|
265
780
|
#else
|
|
266
|
-
|
|
267
|
-
|
|
781
|
+
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
782
|
+
GGML_UNUSED_VARS(t, xs0, stride);
|
|
783
|
+
NO_DEVICE_CODE;
|
|
784
|
+
#else
|
|
785
|
+
load_generic(t, xs0, stride);
|
|
786
|
+
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
268
787
|
#endif // TURING_MMA_AVAILABLE
|
|
269
788
|
}
|
|
270
789
|
|
|
271
|
-
template <typename T>
|
|
790
|
+
template <typename T, data_layout dl>
|
|
272
791
|
static __device__ __forceinline__ void load_ldmatrix(
|
|
273
|
-
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
|
|
792
|
+
tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {
|
|
274
793
|
#if defined(TURING_MMA_AVAILABLE)
|
|
275
794
|
int * xi = (int * ) t.x;
|
|
276
795
|
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
|
|
277
796
|
asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
|
|
278
797
|
: "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
|
|
279
798
|
: "l"(xs));
|
|
799
|
+
#else
|
|
800
|
+
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
801
|
+
#if 1
|
|
802
|
+
// TODO: more generic handling
|
|
803
|
+
static_assert(sizeof(T) == 4, "bad type size");
|
|
804
|
+
ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
|
|
805
|
+
ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
|
|
806
|
+
#else
|
|
807
|
+
load_generic(t, xs0, stride);
|
|
808
|
+
#endif // 1
|
|
280
809
|
#else
|
|
281
810
|
load_generic(t, xs0, stride);
|
|
811
|
+
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
282
812
|
#endif // TURING_MMA_AVAILABLE
|
|
283
813
|
}
|
|
284
814
|
|
|
815
|
+
// Loads a mirrored I-major 8x4 half2 tile with one copy starting at the beginning of row t.get_i(0).
// The copy size passed to ggml_cuda_memcpy_1 is 4*sizeof(half2) (presumably a byte count, TODO confirm).
static __device__ __forceinline__ void load_ldmatrix(
        tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
    const half2 * const row_start = xs0 + t.get_i(0)*stride;
    ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, row_start);
}
|
|
819
|
+
|
|
820
|
+
// Loads a mirrored J-major 8x4 half2 tile two elements at a time:
// every even element index l is copied together with its successor from position (get_i(l), get_j(l)).
static __device__ __forceinline__ void load_ldmatrix(
        tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
#pragma unroll
    for (int l = 0; l < t.ne; l += 2) {
        const half2 * const src = xs0 + t.get_i(l)*stride + t.get_j(l);
        ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l, src);
    }
}
|
|
827
|
+
|
|
828
|
+
// Loads a 32x4 half2 tile. Only implemented for Volta, where a single copy starting at the beginning
// of row t.get_i(0) fills t.x; every other target hits NO_DEVICE_CODE.
static __device__ __forceinline__ void load_ldmatrix(
        tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    const half2 * const row_start = xs0 + t.get_i(0)*stride;
    ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, row_start);
#else
    GGML_UNUSED_VARS(t, xs0, stride);
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
}
|
|
837
|
+
|
|
285
838
|
template <typename T>
|
|
286
839
|
static __device__ __forceinline__ void load_ldmatrix_trans(
|
|
287
840
|
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
|
|
@@ -400,14 +953,54 @@ namespace ggml_cuda_mma {
|
|
|
400
953
|
: "+r"(Dxi[2]), "+r"(Dxi[3])
|
|
401
954
|
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
|
|
402
955
|
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
|
956
|
+
#elif defined(AMD_WMMA_AVAILABLE)
|
|
957
|
+
#if defined(RDNA4)
|
|
958
|
+
using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
|
|
959
|
+
halfx8_t& acc_frag = reinterpret_cast<halfx8_t&>(D.x[0]);
|
|
960
|
+
const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
|
|
961
|
+
const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
|
|
962
|
+
acc_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
|
|
963
|
+
#else
|
|
964
|
+
GGML_UNUSED_VARS(D, A, B);
|
|
965
|
+
NO_DEVICE_CODE;
|
|
966
|
+
#endif // defined(RDNA4)
|
|
967
|
+
#elif defined(AMD_MFMA_AVAILABLE)
|
|
968
|
+
// MFMA: FP16 input, FP32 accumulate, convert back to half2.
|
|
969
|
+
using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
|
|
970
|
+
using floatx4_t = __attribute__((ext_vector_type(4))) float;
|
|
971
|
+
|
|
972
|
+
// Convert existing half2 accumulator to float for MFMA:
|
|
973
|
+
floatx4_t acc_f32;
|
|
974
|
+
{
|
|
975
|
+
const halfx4_t acc_h = reinterpret_cast<const halfx4_t&>(D.x[0]);
|
|
976
|
+
#pragma unroll
|
|
977
|
+
for (int i = 0; i < 4; ++i) {
|
|
978
|
+
acc_f32[i] = (float)acc_h[i];
|
|
979
|
+
}
|
|
980
|
+
}
|
|
981
|
+
|
|
982
|
+
const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
|
|
983
|
+
const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
|
|
984
|
+
acc_f32 = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_f32, 0, 0, 0);
|
|
985
|
+
|
|
986
|
+
// Convert back to half2:
|
|
987
|
+
{
|
|
988
|
+
halfx4_t result_h;
|
|
989
|
+
#pragma unroll
|
|
990
|
+
for (int i = 0; i < 4; ++i) {
|
|
991
|
+
result_h[i] = (_Float16)acc_f32[i];
|
|
992
|
+
}
|
|
993
|
+
reinterpret_cast<halfx4_t&>(D.x[0]) = result_h;
|
|
994
|
+
}
|
|
403
995
|
#else
|
|
404
996
|
GGML_UNUSED_VARS(D, A, B);
|
|
405
997
|
NO_DEVICE_CODE;
|
|
406
998
|
#endif // TURING_MMA_AVAILABLE
|
|
407
999
|
}
|
|
408
1000
|
|
|
1001
|
+
template <data_layout dl_ab, data_layout dl_d>
|
|
409
1002
|
static __device__ __forceinline__ void mma(
|
|
410
|
-
tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
|
|
1003
|
+
tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) {
|
|
411
1004
|
#ifdef AMPERE_MMA_AVAILABLE
|
|
412
1005
|
const int * Axi = (const int *) A.x;
|
|
413
1006
|
const int * Bxi = (const int *) B.x;
|
|
@@ -421,6 +1014,53 @@ namespace ggml_cuda_mma {
|
|
|
421
1014
|
#endif // AMPERE_MMA_AVAILABLE
|
|
422
1015
|
}
|
|
423
1016
|
|
|
1017
|
+
template <data_layout dl_ab, data_layout dl_d>
|
|
1018
|
+
static __device__ __forceinline__ void mma(
|
|
1019
|
+
tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) {
|
|
1020
|
+
#ifdef AMD_MFMA_AVAILABLE
|
|
1021
|
+
using floatx4_t = __attribute__((ext_vector_type(4))) float;
|
|
1022
|
+
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
|
|
1023
|
+
#if defined(CDNA3)
|
|
1024
|
+
using floatx2_t = __attribute__((ext_vector_type(2))) float;
|
|
1025
|
+
const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]);
|
|
1026
|
+
const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]);
|
|
1027
|
+
acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
|
|
1028
|
+
#elif defined(CDNA2) || defined(CDNA1)
|
|
1029
|
+
#pragma unroll
|
|
1030
|
+
for (int i = 0; i < 2; ++i) {
|
|
1031
|
+
acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0);
|
|
1032
|
+
}
|
|
1033
|
+
#else
|
|
1034
|
+
GGML_UNUSED_VARS(D, A, B);
|
|
1035
|
+
NO_DEVICE_CODE;
|
|
1036
|
+
#endif // defined(CDNA3)
|
|
1037
|
+
#else
|
|
1038
|
+
GGML_UNUSED_VARS(D, A, B);
|
|
1039
|
+
NO_DEVICE_CODE;
|
|
1040
|
+
#endif // AMD_MFMA_AVAILABLE
|
|
1041
|
+
}
|
|
1042
|
+
|
|
1043
|
+
// Block-scaled matrix multiply-accumulate into D using a single PTX
// `mma.sync ... kind::mxf4.block_scale.scale_vec::2X.m16n8k64 ... e2m1.e2m1 ... ue8m0` instruction.
//
//   D        16x8 float tile; its 4 registers are both the input accumulator and the output.
//   A        16x8 int tile; 4 registers of packed operand data.
//   B        8x8 int tile; 2 registers of packed operand data.
//   a_scale  scale operand passed to the instruction for A.
//   b_scale  scale operand passed to the instruction for B.
//
// NOTE(review): the packing of the e2m1 values inside A/B and of the ue8m0 scales inside
// a_scale/b_scale is defined by the PTX ISA and the callers, not by this function - confirm there.
//
// Only available when BLACKWELL_MMA_AVAILABLE is defined. On any other target this now reports
// NO_DEVICE_CODE, like the other mma fallbacks in this file, instead of silently leaving D untouched.
static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
                                                        const tile<16, 8, int> & A,
                                                        const tile<8, 8, int> & B,
                                                        uint32_t a_scale,
                                                        uint32_t b_scale) {
#ifdef BLACKWELL_MMA_AVAILABLE
    const int * Axi = (const int *) A.x;
    const int * Bxi = (const int *) B.x;
    float     * Dxi = (float *) D.x;

    asm volatile(
        "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
        "%10, {0, 0}, %11, {0, 0};"
        : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
        : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
#else
    GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);
    NO_DEVICE_CODE;
#endif // BLACKWELL_MMA_AVAILABLE
}
|
|
1063
|
+
|
|
424
1064
|
static __device__ __forceinline__ void mma(
|
|
425
1065
|
tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
|
|
426
1066
|
#ifdef TURING_MMA_AVAILABLE
|
|
@@ -461,8 +1101,9 @@ namespace ggml_cuda_mma {
|
|
|
461
1101
|
#endif // AMPERE_MMA_AVAILABLE
|
|
462
1102
|
}
|
|
463
1103
|
|
|
1104
|
+
template <data_layout dl_ab, data_layout dl_d>
|
|
464
1105
|
static __device__ __forceinline__ void mma(
|
|
465
|
-
tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
|
|
1106
|
+
tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) {
|
|
466
1107
|
#ifdef TURING_MMA_AVAILABLE
|
|
467
1108
|
const int * Axi = (const int *) A.x;
|
|
468
1109
|
const int * Bxi = (const int *) B.x;
|
|
@@ -489,14 +1130,89 @@ namespace ggml_cuda_mma {
|
|
|
489
1130
|
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
|
490
1131
|
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
|
|
491
1132
|
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
|
1133
|
+
#elif defined(AMD_WMMA_AVAILABLE)
|
|
1134
|
+
#if defined(RDNA4)
|
|
1135
|
+
using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
|
|
1136
|
+
using floatx8_t = __attribute__((ext_vector_type(8))) float;
|
|
1137
|
+
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
|
|
1138
|
+
const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
|
|
1139
|
+
const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
|
|
1140
|
+
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
|
|
1141
|
+
#elif defined(RDNA3)
|
|
1142
|
+
using halfx16_t = __attribute__((ext_vector_type(16))) _Float16;
|
|
1143
|
+
using floatx8_t = __attribute__((ext_vector_type(8))) float;
|
|
1144
|
+
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
|
|
1145
|
+
const halfx16_t& a_frag = reinterpret_cast<const halfx16_t&>(A.x[0]);
|
|
1146
|
+
const halfx16_t& b_frag = reinterpret_cast<const halfx16_t&>(B.x[0]);
|
|
1147
|
+
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, acc_frag);
|
|
1148
|
+
#else
|
|
1149
|
+
GGML_UNUSED_VARS(D, A, B);
|
|
1150
|
+
NO_DEVICE_CODE;
|
|
1151
|
+
#endif // RDNA4
|
|
1152
|
+
#elif defined(AMD_MFMA_AVAILABLE)
|
|
1153
|
+
using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
|
|
1154
|
+
using floatx4_t = __attribute__((ext_vector_type(4))) float;
|
|
1155
|
+
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
|
|
1156
|
+
const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
|
|
1157
|
+
const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
|
|
1158
|
+
acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0);
|
|
492
1159
|
#else
|
|
493
1160
|
GGML_UNUSED_VARS(D, A, B);
|
|
494
1161
|
NO_DEVICE_CODE;
|
|
495
1162
|
#endif // TURING_MMA_AVAILABLE
|
|
496
1163
|
}
|
|
497
1164
|
|
|
1165
|
+
template <data_layout dl_ab, data_layout dl_d>
|
|
1166
|
+
static __device__ __forceinline__ void mma(
|
|
1167
|
+
tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {
|
|
1168
|
+
#if defined(AMD_WMMA_AVAILABLE)
|
|
1169
|
+
#if defined(RDNA4)
|
|
1170
|
+
using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
|
|
1171
|
+
using floatx8_t = __attribute__((ext_vector_type(8))) float;
|
|
1172
|
+
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
|
|
1173
|
+
const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
|
|
1174
|
+
const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
|
|
1175
|
+
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
|
|
1176
|
+
#elif defined(RDNA3)
|
|
1177
|
+
using bf16x16_t = __attribute__((ext_vector_type(16))) __bf16;
|
|
1178
|
+
using floatx8_t = __attribute__((ext_vector_type(8))) float;
|
|
1179
|
+
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
|
|
1180
|
+
const bf16x16_t& a_frag = reinterpret_cast<const bf16x16_t&>(A.x[0]);
|
|
1181
|
+
const bf16x16_t& b_frag = reinterpret_cast<const bf16x16_t&>(B.x[0]);
|
|
1182
|
+
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_frag, b_frag, acc_frag);
|
|
1183
|
+
#else
|
|
1184
|
+
GGML_UNUSED_VARS(D, A, B);
|
|
1185
|
+
NO_DEVICE_CODE;
|
|
1186
|
+
#endif // defined(RDNA4)
|
|
1187
|
+
#elif defined(AMD_MFMA_AVAILABLE)
|
|
1188
|
+
using floatx4_t = __attribute__((ext_vector_type(4))) float;
|
|
1189
|
+
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
|
|
1190
|
+
#if defined(CDNA3) || defined(CDNA2)
|
|
1191
|
+
using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16;
|
|
1192
|
+
const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[0]);
|
|
1193
|
+
const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[0]);
|
|
1194
|
+
acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0);
|
|
1195
|
+
#elif defined(CDNA1)
|
|
1196
|
+
#pragma unroll
|
|
1197
|
+
for (int i = 0; i < 2; ++i) {
|
|
1198
|
+
using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16;
|
|
1199
|
+
const bf16x2_t& a_frag = reinterpret_cast<const bf16x2_t&>(A.x[i]);
|
|
1200
|
+
const bf16x2_t& b_frag = reinterpret_cast<const bf16x2_t&>(B.x[i]);
|
|
1201
|
+
acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0);
|
|
1202
|
+
}
|
|
1203
|
+
#else
|
|
1204
|
+
GGML_UNUSED_VARS(D, A, B);
|
|
1205
|
+
NO_DEVICE_CODE;
|
|
1206
|
+
#endif // defined(CDNA3) || defined(CDNA2)
|
|
1207
|
+
#else
|
|
1208
|
+
GGML_UNUSED_VARS(D, A, B);
|
|
1209
|
+
NO_DEVICE_CODE;
|
|
1210
|
+
#endif // defined(AMD_WMMA_AVAILABLE)
|
|
1211
|
+
}
|
|
1212
|
+
|
|
1213
|
+
template <data_layout dl_d, data_layout dl_ab>
|
|
498
1214
|
static __device__ __forceinline__ void mma(
|
|
499
|
-
tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
|
|
1215
|
+
tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {
|
|
500
1216
|
#if defined(AMD_MFMA_AVAILABLE)
|
|
501
1217
|
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
|
|
502
1218
|
int32x4_t * acc = (int32x4_t *) D.x;
|
|
@@ -515,6 +1231,59 @@ namespace ggml_cuda_mma {
|
|
|
515
1231
|
acc[0],
|
|
516
1232
|
0, 0, 0);
|
|
517
1233
|
#endif // defined(CDNA3)
|
|
1234
|
+
|
|
1235
|
+
#elif defined(AMD_WMMA_AVAILABLE)
|
|
1236
|
+
|
|
1237
|
+
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
|
|
1238
|
+
int32x8_t * acc = (int32x8_t *) D.x;
|
|
1239
|
+
|
|
1240
|
+
#if defined(RDNA4)
|
|
1241
|
+
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
|
|
1242
|
+
int32x2_t * a_vec = (int32x2_t *) A.x;
|
|
1243
|
+
int32x2_t * b_vec = (int32x2_t *) B.x;
|
|
1244
|
+
|
|
1245
|
+
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
|
|
1246
|
+
true,
|
|
1247
|
+
a_vec[0],
|
|
1248
|
+
true,
|
|
1249
|
+
b_vec[0],
|
|
1250
|
+
acc[0],
|
|
1251
|
+
true
|
|
1252
|
+
);
|
|
1253
|
+
|
|
1254
|
+
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
|
|
1255
|
+
true,
|
|
1256
|
+
a_vec[1],
|
|
1257
|
+
true,
|
|
1258
|
+
b_vec[1],
|
|
1259
|
+
acc[0],
|
|
1260
|
+
true
|
|
1261
|
+
);
|
|
1262
|
+
|
|
1263
|
+
#elif defined(RDNA3)
|
|
1264
|
+
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
|
|
1265
|
+
int32x4_t * a_vec = (int32x4_t *) A.x;
|
|
1266
|
+
int32x4_t * b_vec = (int32x4_t *) B.x;
|
|
1267
|
+
|
|
1268
|
+
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
|
|
1269
|
+
true,
|
|
1270
|
+
a_vec[0],
|
|
1271
|
+
true,
|
|
1272
|
+
b_vec[0],
|
|
1273
|
+
acc[0],
|
|
1274
|
+
true
|
|
1275
|
+
);
|
|
1276
|
+
|
|
1277
|
+
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
|
|
1278
|
+
true,
|
|
1279
|
+
a_vec[1],
|
|
1280
|
+
true,
|
|
1281
|
+
b_vec[1],
|
|
1282
|
+
acc[0],
|
|
1283
|
+
true
|
|
1284
|
+
);
|
|
1285
|
+
#endif // RDNA4
|
|
1286
|
+
|
|
518
1287
|
#else
|
|
519
1288
|
GGML_UNUSED_VARS(D, A, B);
|
|
520
1289
|
NO_DEVICE_CODE;
|
|
@@ -541,9 +1310,100 @@ namespace ggml_cuda_mma {
|
|
|
541
1310
|
acc[0],
|
|
542
1311
|
0, 0, 0);
|
|
543
1312
|
#endif // defined(CDNA3)
|
|
1313
|
+
|
|
544
1314
|
#else
|
|
545
1315
|
GGML_UNUSED_VARS(D, A, B);
|
|
546
1316
|
NO_DEVICE_CODE;
|
|
547
1317
|
#endif // AMD_MFMA_AVAILABLE
|
|
548
1318
|
}
|
|
1319
|
+
|
|
1320
|
+
template <typename T1, typename T2, int J, int K>
|
|
1321
|
+
static __device__ __forceinline__ void mma(
|
|
1322
|
+
tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
|
|
1323
|
+
tile <16, J, T1> * D16 = reinterpret_cast< tile<16, J, T1> *>(&D);
|
|
1324
|
+
const tile<16, K, T2> * A16 = reinterpret_cast<const tile<16, K, T2> *>(&A);
|
|
1325
|
+
mma(D16[0], A16[0], B);
|
|
1326
|
+
mma(D16[1], A16[1], B);
|
|
1327
|
+
}
|
|
1328
|
+
|
|
1329
|
+
// Volta-only D += A * B with half2 inputs and an FP32 accumulator (32x8 float tile D).
// Compiled only when __CUDA_ARCH__ is exactly GGML_CUDA_CC_VOLTA; otherwise NO_DEVICE_CODE.
// Two m8n8k4 row.col instructions are issued back to back, both accumulating into the same
// 8 registers of D: the first consumes registers 0-1 of A and B, the second registers 2-3.
static __device__ __forceinline__ void mma(
        tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    // The tiles are accessed as raw 32-bit registers for the "r" asm constraints.
    const int * a_regs = (const int *) A.x;
    const int * b_regs = (const int *) B.x;
    int       * d_regs = (int *) D.x;

    asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
        : "+r"(d_regs[0]), "+r"(d_regs[1]), "+r"(d_regs[2]), "+r"(d_regs[3]), "+r"(d_regs[4]), "+r"(d_regs[5]), "+r"(d_regs[6]), "+r"(d_regs[7])
        : "r"(a_regs[0]), "r"(a_regs[1]), "r"(b_regs[0]), "r"(b_regs[1]));
    asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
        : "+r"(d_regs[0]), "+r"(d_regs[1]), "+r"(d_regs[2]), "+r"(d_regs[3]), "+r"(d_regs[4]), "+r"(d_regs[5]), "+r"(d_regs[6]), "+r"(d_regs[7])
        : "r"(a_regs[2]), "r"(a_regs[3]), "r"(b_regs[2]), "r"(b_regs[3]));
#else
    GGML_UNUSED_VARS(D, A, B);
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
}
|
|
1348
|
+
|
|
1349
|
+
// Volta-only D += A * B with half2 inputs and a half2 accumulator (32x4 half2 tile D).
// Compiled only when __CUDA_ARCH__ is exactly GGML_CUDA_CC_VOLTA; otherwise NO_DEVICE_CODE.
// Two m8n8k4 row.row f16 instructions are issued back to back, both accumulating into the same
// 4 registers of D: the first consumes registers 0-1 of A and B, the second registers 2-3.
static __device__ __forceinline__ void mma(
        tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    // The tiles are accessed as raw 32-bit registers for the "r" asm constraints.
    const int * a_regs = (const int *) A.x;
    const int * b_regs = (const int *) B.x;
    int       * d_regs = (int *) D.x;

    asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
        "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
        : "+r"(d_regs[0]), "+r"(d_regs[1]), "+r"(d_regs[2]), "+r"(d_regs[3])
        : "r"(a_regs[0]), "r"(a_regs[1]), "r"(b_regs[0]), "r"(b_regs[1]));
    asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
        "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
        : "+r"(d_regs[0]), "+r"(d_regs[1]), "+r"(d_regs[2]), "+r"(d_regs[3])
        : "r"(a_regs[2]), "r"(a_regs[3]), "r"(b_regs[2]), "r"(b_regs[3]));
#else
    GGML_UNUSED_VARS(D, A, B);
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
}
|
|
1368
|
+
|
|
1369
|
+
template <data_layout dl_d, data_layout dl_ab>
|
|
1370
|
+
static __device__ __forceinline__ void mma(
|
|
1371
|
+
tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
|
|
1372
|
+
#if defined(AMD_WMMA_AVAILABLE)
|
|
1373
|
+
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
|
|
1374
|
+
int32x8_t * acc = (int32x8_t *) D.x;
|
|
1375
|
+
#if defined(RDNA4)
|
|
1376
|
+
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
|
|
1377
|
+
int32x2_t * a_vec = (int32x2_t *) A.x;
|
|
1378
|
+
int32x2_t * b_vec = (int32x2_t *) B.x;
|
|
1379
|
+
|
|
1380
|
+
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
|
|
1381
|
+
true,
|
|
1382
|
+
a_vec[0],
|
|
1383
|
+
true,
|
|
1384
|
+
b_vec[0],
|
|
1385
|
+
acc[0],
|
|
1386
|
+
false
|
|
1387
|
+
);
|
|
1388
|
+
#elif defined(RDNA3)
|
|
1389
|
+
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
|
|
1390
|
+
int32x4_t * a_vec = (int32x4_t *) A.x;
|
|
1391
|
+
int32x4_t * b_vec = (int32x4_t *) B.x;
|
|
1392
|
+
|
|
1393
|
+
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
|
|
1394
|
+
true,
|
|
1395
|
+
a_vec[0],
|
|
1396
|
+
true,
|
|
1397
|
+
b_vec[0],
|
|
1398
|
+
acc[0],
|
|
1399
|
+
false
|
|
1400
|
+
);
|
|
1401
|
+
#endif // RDNA4
|
|
1402
|
+
#else
|
|
1403
|
+
GGML_UNUSED(D);
|
|
1404
|
+
GGML_UNUSED(A);
|
|
1405
|
+
GGML_UNUSED(B);
|
|
1406
|
+
NO_DEVICE_CODE;
|
|
1407
|
+
#endif // AMD_WMMA_AVAILABLE
|
|
1408
|
+
}
|
|
549
1409
|
}
|