whispercpp 1.3.3 → 1.3.5
This diff shows the changes between publicly available package versions as they were released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +79 -25
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/addon.cpp +19 -19
- data/ext/sources/examples/addon.node/index.js +7 -5
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +26 -16
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +122 -111
- data/ext/sources/examples/command/command.cpp +26 -24
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/lsp/lsp.cpp +19 -17
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +34 -24
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +4 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
- data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
- data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
- data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
- data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
- data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
- data/ext/sources/examples/talk-llama/llama-context.h +99 -36
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
- data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
- data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
- data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
- data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
- data/ext/sources/examples/talk-llama/llama-model.h +104 -12
- data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
- data/ext/sources/examples/talk-llama/llama.cpp +794 -12
- data/ext/sources/examples/talk-llama/llama.h +246 -190
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
- data/ext/sources/ggml/CMakeLists.txt +135 -79
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +21 -2
- data/ext/sources/ggml/include/ggml-cpu.h +2 -1
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +406 -23
- data/ext/sources/ggml/src/CMakeLists.txt +99 -13
- data/ext/sources/ggml/src/ggml-alloc.c +368 -161
- data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
- data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
- data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
- data/ext/sources/ggml/src/ggml-common.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
- data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
- data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
- data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
- data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
- data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
- data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
- data/ext/sources/ggml/src/ggml-impl.h +186 -15
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +111 -16
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +901 -129
- data/ext/sources/ggml/src/gguf.cpp +8 -1
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +124 -81
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +7 -5
- data/ext/sources/tests/test-vad.cpp +3 -3
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +126 -2
- data/test/test_params.rb +24 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +8 -1
- data/whispercpp.gemspec +1 -1
- metadata +439 -179
- data/ext/sources/build-xcframework.sh +0 -547
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
@@ -1,3 +1,4 @@
+#pragma once
 // This file contains primitives that expose the tensor core PTX instructions for CUDA code.
 // The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
 // The documentation for the PTX instructions can be found under:
@@ -12,23 +13,28 @@
 // The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
 // All matrix tiles have ne physical 32 bit elements per warp.
 //
-// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
+// As described in the PTX documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
+// The API in this file also assumes that the pointers for load_generic are aligned to 16 bytes, unaligned pointers are considered undefined behavior.
 
 #include "common.cuh"
 
+// On Volta each warp is doing 4 8x8 mma operations in parallel.
+// The basic memory layout for a 32x8 output tile is to stack 4 input tiles in I direction and to mirror the B tile.
+// However, the i indices in this file are by default permuted to simplify the index calculations.
+// #define GGML_CUDA_MMA_NO_VOLTA_PERM
 
 #if CUDART_VERSION >= 11080
 
 static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
     int ret = 0;
 
-#ifdef
+#ifdef TURING_MMA_AVAILABLE
     asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
         : "=r"(ret) : "r"(x));
 #else
     GGML_UNUSED(x);
     NO_DEVICE_CODE;
-#endif // defined(
+#endif // defined(TURING_MMA_AVAILABLE)
     return ret;
 }
 
@@ -62,22 +68,187 @@ static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
 
 namespace ggml_cuda_mma {
 
+    // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel,
+    // effectively the warp is being split into subgroups of threads that each perform a single mma instruction.
+    // In those cases the data can be split in different ways across the warp.
+    enum data_layout {
+        // By default the data uses the I direction as its major dimension and the J direction as its minor dimension.
+        // For the A/C matrices this means I major == row major, J major == column major.
+        // For the B matrix this means I major == column major, J major == row major.
+        // MIRRORED == Each data value is held exactly once per thread subgroup.
+        DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
+        DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
+        DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3.
+        DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
+    };
+    // Implemented mma combinations are:
+    // - (I_MAJOR, I_MAJOR) -> I_MAJOR
+    // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
+    // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
+
+    static constexpr bool is_i_major(const data_layout dl) {
+        return dl == DATA_LAYOUT_I_MAJOR ||
+            dl == DATA_LAYOUT_I_MAJOR_MIRRORED;
+    }
+
+    static constexpr __device__ data_layout get_input_data_layout() {
+#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        return DATA_LAYOUT_I_MAJOR_MIRRORED;
+#else
+        return DATA_LAYOUT_I_MAJOR;
+#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+    }
+
+    template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
+    struct tile {};
+
     template <int I_, int J_, typename T>
-    struct tile {
-        static constexpr int
-        static constexpr int
-        static constexpr
+    struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
+
+#if defined(AMD_MFMA_AVAILABLE)
+        static constexpr int ne = I * J / 64;
         T x[ne] = {0};
 
+        static constexpr __device__ bool supported() {
+            if (I == 64 && J == 2) return true;
+            if (I == 16 && J == 8) return true;
+            if (I == 32 && J == 4) return true;
+            if (I == 16 && J == 16) return true;
+            if (I == 32 && J == 32) return true;
+            return false;
+        }
+
         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I ==
+            if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
+                return threadIdx.x % 16;
+            } else if constexpr (I == 16 && J == 8) {
+                return threadIdx.x % 16;
+            } else if constexpr (I == 32 && J == 4) {
+                return threadIdx.x % 32;
+            } else if constexpr (I == 16 && J == 16) {
+                return threadIdx.x % 16;
+            } else if constexpr (I == 32 && J == 32) {
+                return threadIdx.x % 32;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
+                return (2 * ((threadIdx.x / 16) % 2) + l);
+            } else if constexpr (I == 16 && J == 8) {
+                return 2 * (threadIdx.x / 16) + l;
+            } else if constexpr (I == 32 && J == 4) {
+                return 2 * (threadIdx.x / 32) + l;
+            } else if constexpr (I == 16 && J == 16) {
+                return 4 * (threadIdx.x / 16) + l;
+            } else if constexpr (I == 32 && J == 32) {
+                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        static constexpr int ne = I * J / 32;
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            if (I == 32 && J == 8) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 32 && J == 8) {
+#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
+                return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2);
+#else
+                return (l & 2) + (threadIdx.x & ~2);
+#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 32 && J == 8) {
+                return (threadIdx.x & 2) + (l & (4 + 1));
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#elif defined(AMD_WMMA_AVAILABLE)
+        static constexpr int ne = I * J / 32;
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 16) return true;
+            if (I == 16 && J == 8) return true;
+            if (I == 16 && J == 4) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (supported()) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 16) {
+                // matrix C
+#if defined(RDNA3)
+                return 2 * l + (threadIdx.x / 16);
+#else
+                return ne * (threadIdx.x / 16) + l;
+#endif // defined(RDNA3)
+            } else if constexpr (I == 16 && J == 8) {
+                // mmq input for RDNA4
+                return ne * (threadIdx.x / 16) + l;
+            } else if constexpr (I == 16 && J == 4) {
+                return ne * (threadIdx.x / 16) + l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#else
+        static constexpr int ne = I * J / 32;
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            if (I == 8 && J == 4) return true;
+            if (I == 8 && J == 8) return true;
+            if (I == 16 && J == 8) return true;
+            if (I == 16 && J == 16) return true;
+            if (I == 32 && J == 8) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 4) {
+                return threadIdx.x / 4;
+            } else if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 8) {
-                return (l / 2) * 8 + threadIdx.x / 4;
+                return ((l / 2) * 8) + (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 16) {
-                return ((l / 2) % 2) * 8 + threadIdx.x / 4;
+                return (((l / 2) % 2) * 8) + (threadIdx.x / 4);
+            } else if constexpr (I == 32 && J == 8) {
+                return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
             } else {
-
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
 
@@ -85,49 +256,354 @@ namespace ggml_cuda_mma {
             if constexpr (I == 8 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 8 && J == 8) {
-                return
+                return (l * 4) + (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 8) {
-                return
+                return ((threadIdx.x % 4) * 2) + (l % 2);
             } else if constexpr (I == 16 && J == 16) {
-                return
+                return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2);
+            } else if constexpr (I == 32 && J == 8) {
+                return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
             } else {
-
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
+#endif // defined(GGML_USE_HIP)
     };
 
     template <int I_, int J_>
-    struct tile<I_, J_, half2> {
-        static constexpr int
-        static constexpr int
+    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
+
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        static constexpr int ne = I * J / WARP_SIZE;
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I == 32 && J == 4) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 32 && J == 4) {
+#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
+                return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
+#else
+                return threadIdx.x;
+#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 32 && J == 4) {
+                return l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#elif defined(AMD_WMMA_AVAILABLE)
+        static constexpr int ne = I * J / 32;
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 8) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return 4 * (threadIdx.x / 16) + l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#else
         static constexpr int ne = I * J / WARP_SIZE;
         half2 x[ne] = {{0.0f, 0.0f}};
 
+        static constexpr __device__ bool supported() {
+            if (I == 8 && J == 4) return true;
+            if (I == 8 && J == 8) return true;
+            if (I == 16 && J == 8) return true;
+            if (I == 16 && J == 16) return true;
+            if (I == 32 && J == 8) return true;
+            return false;
+        }
+
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 4) {
-                return l * 8 + threadIdx.x / 4;
+                return (l * 8) + (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 8) {
-                return (l % 2) * 8 + threadIdx.x / 4;
+                return ((l % 2) * 8) + (threadIdx.x / 4);
+            } else if constexpr (I == 32 && J == 8) {
+                return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4);
             } else {
-
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 8) {
-                return l * 4 + threadIdx.x % 4;
+                return (l * 4) + (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 16 && J == 8) {
-                return (l / 2) * 4 + threadIdx.x % 4;
+                return ((l / 2) * 4) + (threadIdx.x % 4);
+            } else if constexpr (I == 32 && J == 8) {
+                return ((l & 2) * 2) + (threadIdx.x % 4);
             } else {
-
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
     };
 
+    template <int I_, int J_>
+    struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
+
+#if defined(AMD_WMMA_AVAILABLE)
+        static constexpr int ne = I * J / 32;
+        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
+        }
+#else
+        static constexpr int ne = I * J / WARP_SIZE;
+        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I == 8 && J == 8) return true;
+            if (I == 16 && J == 4) return true;
+            if (I == 16 && J == 8) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 4) {
+                return (l * 8) + (threadIdx.x / 4);
+            } else if constexpr (I == 16 && J == 8) {
+                return ((l % 2) * 8) + (threadIdx.x / 4);
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return (l * 4) + (threadIdx.x % 4);
+            } else if constexpr (I == 16 && J == 4) {
+                return threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return ((l / 2) * 4) + (threadIdx.x % 4);
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#endif // defined(AMD_WMMA_AVAILABLE)
+    };
+
+    template <int I_, int J_, typename T>
+    struct tile<I_, J_, T, DATA_LAYOUT_J_MAJOR> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR;
+
+        static constexpr int ne = tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::ne;
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::supported();
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_j(l);
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_i(l);
+        }
+    };
+
+    template <int I_, int J_, typename T>
+    struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_MIRRORED> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
+
+        // RDNA3
+        static constexpr int ne = I * J / 32 * 2;
+
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 16) return true;
+            if (I == 16 && J == 8) return true;
+            if (I == 16 && J == 4) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int /*l*/) {
+            if constexpr (supported()) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (supported()) {
+                return l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+    };
+
+    template <int I_, int J_>
+    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
+#if defined(RDNA3)
+        static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
+
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
+        }
+#else // Volta
+        static constexpr int ne = I * J / (WARP_SIZE/4);
+
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I == 8 && J == 4) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int /*l*/) {
+            if constexpr (I == 8 && J == 4) {
+                return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 4) {
+                return l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#endif // defined(RDNA3)
+    };
+
+    template <int I_, int J_>
+    struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR_MIRRORED> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
+        static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
+
+        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
+        }
+    };
+
+    template <int I_, int J_>
+    struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
+        static constexpr int ne = I * J / (WARP_SIZE/4);
+
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I == 8 && J == 4) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 4) {
+                return ((l / 2) * 4) + (threadIdx.x % 4);
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 4) {
+                return ((threadIdx.x / 16) * 2) + (l % 2);
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+    };
+
+#if defined(TURING_MMA_AVAILABLE)
     template <int I, int J>
     static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
         tile<I, J/2, half2> ret;
@@ -145,19 +621,68 @@ namespace ggml_cuda_mma {
 
         return ret;
     }
+#else // Volta
+    template <int I, int J>
+    static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
+        tile<I, J/2, half2> ret;
+#pragma unroll
+        for (int l0 = 0; l0 < tile_float.ne; l0 += 4) {
+            ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
+            ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]);
 
-
-
+            // On Volta FP16 and FP32 tiles have a different memory layout,
+            // for the conversion threads with an offset of 2 need to exchange half their values:
+            ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync(
+                0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE);
+        }
+        return ret;
+    }
+#endif // defined(TURING_MMA_AVAILABLE)
+
+    template <int I, int J, typename T, data_layout dl>
+    static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
+#if defined(AMD_MFMA_AVAILABLE)
+        if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
+#pragma unroll
+            for (int l = 0; l < t.ne; ++l) {
+                t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
+            }
+        } else {
+            ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
+        }
+#elif defined(AMD_WMMA_AVAILABLE)
+        // All wmma layout has contiguous data when i-major.
+        if constexpr (is_i_major(dl)) {
+            // the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes()
+            constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();
+            if constexpr (sizeof(t.x) > aligned_copy_bytes) {
+                static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size");
+                constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;
+#pragma unroll
+                for (int i = 0; i < aligned_copy_count; ++i) {
+                    ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));
+                }
+            } else {
+                ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
+            }
+        } else {
+#pragma unroll
+            for (int l = 0; l < t.ne; ++l) {
+                t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
+            }
+        }
+#else
 #pragma unroll
         for (int l = 0; l < t.ne; ++l) {
             t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
         }
+#endif // defined(AMD_MFMA_AVAILABLE)
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef
+#ifdef TURING_MMA_AVAILABLE
         int * xi = (int *) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
         asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
@@ -165,58 +690,94 @@ namespace ggml_cuda_mma {
             : "l"(xs));
 #else
         load_generic(t, xs0, stride);
-#endif //
+#endif // TURING_MMA_AVAILABLE
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef
+#ifdef TURING_MMA_AVAILABLE
         int * xi = (int *) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
         asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
             : "=r"(xi[0]), "=r"(xi[1])
             : "l"(xs));
 #else
-
-
-
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        GGML_UNUSED_VARS(t, xs0, stride);
+        NO_DEVICE_CODE;
+#else
+        load_generic(t, xs0, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#endif // TURING_MMA_AVAILABLE
     }
 
-    template <typename T>
+    template <typename T, data_layout dl>
     static __device__ __forceinline__ void load_ldmatrix(
-            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#
+            tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {
+#if defined(TURING_MMA_AVAILABLE)
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
             : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
             : "l"(xs));
+#else
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if 1
+        // TODO: more generic handling
+        static_assert(sizeof(T) == 4, "bad type size");
+        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
+        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
 #else
         load_generic(t, xs0, stride);
-#endif //
+#endif // 1
+#else
+        load_generic(t, xs0, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#endif // TURING_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
+        ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
+    }
+
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
+#pragma unroll
+        for (int l0 = 0; l0 < t.ne; l0 += 2) {
+            ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0));
+        }
+    }
+
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
+#else
+        GGML_UNUSED_VARS(t, xs0, stride);
+        NO_DEVICE_CODE;
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix_trans(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef
+#ifdef TURING_MMA_AVAILABLE
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
             : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
             : "l"(xs));
 #else
-
-        GGML_UNUSED(xs0);
-        GGML_UNUSED(stride);
+        GGML_UNUSED_VARS(t, xs0, stride);
         NO_DEVICE_CODE;
-#endif //
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
-#ifdef
+#ifdef TURING_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
             : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
@@ -231,16 +792,14 @@ namespace ggml_cuda_mma {
             : "r"(A.x[1]), "r"(B.x[0]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-
-        GGML_UNUSED(A);
-        GGML_UNUSED(B);
+        GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
-#endif //
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
-#ifdef
+#ifdef TURING_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
             : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
@@ -261,16 +820,14 @@ namespace ggml_cuda_mma {
             : "r"(A.x[3]), "r"(B.x[1]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-
-        GGML_UNUSED(A);
-        GGML_UNUSED(B);
+        GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
-#endif //
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
-#ifdef
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int * Dxi = (int *) D.x;
@@ -288,16 +845,14 @@ namespace ggml_cuda_mma {
             : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-
-        GGML_UNUSED(A);
-        GGML_UNUSED(B);
+        GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
-#endif //
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
-#ifdef
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int * Dxi = (int *) D.x;
@@ -324,16 +879,51 @@ namespace ggml_cuda_mma {
             : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-
-
-
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // TURING_MMA_AVAILABLE
+    }
+
+    template <data_layout dl_ab, data_layout dl_d>
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) {
+#ifdef AMPERE_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int * Dxi = (int *) D.x;
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+        GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
-#endif //
+#endif // AMPERE_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
+            const tile<16, 8, int> & A,
+            const tile<8, 8, int> & B,
+            uint32_t a_scale,
+            uint32_t b_scale) {
+#ifdef BLACKWELL_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        float * Dxi = (float *) D.x;
+
+        asm volatile(
+            "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
+            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
+            "%10, {0, 0}, %11, {0, 0};"
+            : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
+#else
+        GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);
+#endif // BLACKWELL_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
-#ifdef
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int * Dxi = (int *) D.x;
@@ -351,16 +941,30 @@ namespace ggml_cuda_mma {
             : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-
-
-
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // TURING_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<8, 8, nv_bfloat162> & B) {
+#ifdef AMPERE_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int * Dxi = (int *) D.x;
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+        GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
-#endif //
+#endif // AMPERE_MMA_AVAILABLE
     }
 
+    template <data_layout dl_ab, data_layout dl_d>
     static __device__ __forceinline__ void mma(
-            tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
-#ifdef
+            tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) {
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int * Dxi = (int *) D.x;
@@ -386,11 +990,253 @@ namespace ggml_cuda_mma {
             : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
             : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#elif defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
+        using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
+        using floatx8_t = __attribute__((ext_vector_type(8))) float;
+        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
+        const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
+        const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
+#elif defined(RDNA3)
+        using halfx16_t = __attribute__((ext_vector_type(16))) _Float16;
+        using floatx8_t = __attribute__((ext_vector_type(8))) float;
+        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
+        const halfx16_t& a_frag = reinterpret_cast<const halfx16_t&>(A.x[0]);
+        const halfx16_t& b_frag = reinterpret_cast<const halfx16_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, acc_frag);
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // RDNA4
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // TURING_MMA_AVAILABLE
+    }
+
+    template <data_layout dl_ab, data_layout dl_d>
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {
+#if defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
+        using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
+        using floatx8_t = __attribute__((ext_vector_type(8))) float;
+        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
+        const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
+        const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
+#elif defined(RDNA3)
+        using bf16x16_t = __attribute__((ext_vector_type(16))) __bf16;
+        using floatx8_t = __attribute__((ext_vector_type(8))) float;
+        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
+        const bf16x16_t& a_frag = reinterpret_cast<const bf16x16_t&>(A.x[0]);
+        const bf16x16_t& b_frag = reinterpret_cast<const bf16x16_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_frag, b_frag, acc_frag);
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // RDNA4
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // AMPERE_MMA_AVAILABLE
+    }
+
+    template <data_layout dl_d, data_layout dl_ab>
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {
+#if defined(AMD_MFMA_AVAILABLE)
+        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
+        int32x4_t * acc = (int32x4_t *) D.x;
+#if defined(CDNA3)
+        acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0],
+            ((int64_t *) B.x)[0],
+            acc[0],
+            0, 0, 0);
+#elif defined(CDNA2) || defined(CDNA)
+        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0],
+            B.x[0],
+            acc[0],
+            0, 0, 0);
+        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1],
+            B.x[1],
+            acc[0],
+            0, 0, 0);
+#endif // defined(CDNA3)
+
+#elif defined(AMD_WMMA_AVAILABLE)
+
+        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
+        int32x8_t * acc = (int32x8_t *) D.x;
+
+#if defined(RDNA4)
+        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
+        int32x2_t * a_vec = (int32x2_t *) A.x;
+        int32x2_t * b_vec = (int32x2_t *) B.x;
+
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+            true,
+            a_vec[0],
+            true,
+            b_vec[0],
+            acc[0],
+            true
+        );
+
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+            true,
+            a_vec[1],
+            true,
+            b_vec[1],
+            acc[0],
+            true
+        );
+
+#elif defined(RDNA3)
+        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
+        int32x4_t * a_vec = (int32x4_t *) A.x;
+        int32x4_t * b_vec = (int32x4_t *) B.x;
+
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
+            true,
+            a_vec[0],
+            true,
+            b_vec[0],
+            acc[0],
+            true
+        );
+
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
+            true,
+            a_vec[1],
+            true,
+            b_vec[1],
+            acc[0],
+            true
+        );
+#endif // RDNA4
+
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // AMD_MFMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) {
+#if defined(AMD_MFMA_AVAILABLE)
+        using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
+        int32x16_t * acc = (int32x16_t *) D.x;
+#if defined(CDNA3)
+        acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0],
+            ((int64_t *) B.x)[0],
+            acc[0],
+            0, 0, 0);
+#elif defined(CDNA2) || defined(CDNA)
+        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0],
+            B.x[0],
+            acc[0],
+            0, 0, 0);
+        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1],
+            B.x[1],
+            acc[0],
+            0, 0, 0);
+#endif // defined(CDNA3)
+
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // AMD_MFMA_AVAILABLE
+    }
+
+    template <typename T1, typename T2, int J, int K>
+    static __device__ __forceinline__ void mma(
+            tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
+        tile <16, J, T1> * D16 = reinterpret_cast< tile<16, J, T1> *>(&D);
+        const tile<16, K, T2> * A16 = reinterpret_cast<const tile<16, K, T2> *>(&A);
+        mma(D16[0], A16[0], B);
+        mma(D16[1], A16[1], B);
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int * Dxi = (int *) D.x;
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
1185
|
+
const int * Axi = (const int *) A.x;
|
|
1186
|
+
const int * Bxi = (const int *) B.x;
|
|
1187
|
+
int * Dxi = (int *) D.x;
|
|
1188
|
+
asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
|
|
1189
|
+
"{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
|
|
1190
|
+
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
|
1191
|
+
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
|
|
1192
|
+
asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
|
|
1193
|
+
"{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
|
|
1194
|
+
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
|
1195
|
+
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
|
|
1196
|
+
#else
|
|
1197
|
+
GGML_UNUSED_VARS(D, A, B);
|
|
1198
|
+
NO_DEVICE_CODE;
|
|
1199
|
+
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
|
1200
|
+
}
|
|
1201
|
+
|
|
1202
|
+
template <data_layout dl_d, data_layout dl_ab>
|
|
1203
|
+
static __device__ __forceinline__ void mma(
|
|
1204
|
+
tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
|
|
1205
|
+
#if defined(AMD_WMMA_AVAILABLE)
|
|
1206
|
+
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
|
|
1207
|
+
int32x8_t * acc = (int32x8_t *) D.x;
|
|
1208
|
+
#if defined(RDNA4)
|
|
1209
|
+
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
|
|
1210
|
+
int32x2_t * a_vec = (int32x2_t *) A.x;
|
|
1211
|
+
int32x2_t * b_vec = (int32x2_t *) B.x;
|
|
1212
|
+
|
|
1213
|
+
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
|
|
1214
|
+
true,
|
|
1215
|
+
a_vec[0],
|
|
1216
|
+
true,
|
|
1217
|
+
b_vec[0],
|
|
1218
|
+
acc[0],
|
|
1219
|
+
false
|
|
1220
|
+
);
|
|
1221
|
+
#elif defined(RDNA3)
|
|
1222
|
+
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
|
|
1223
|
+
int32x4_t * a_vec = (int32x4_t *) A.x;
|
|
1224
|
+
int32x4_t * b_vec = (int32x4_t *) B.x;
|
|
1225
|
+
|
|
1226
|
+
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
|
|
1227
|
+
true,
|
|
1228
|
+
a_vec[0],
|
|
1229
|
+
true,
|
|
1230
|
+
b_vec[0],
|
|
1231
|
+
acc[0],
|
|
1232
|
+
false
|
|
1233
|
+
);
|
|
1234
|
+
#endif // RDNA4
|
|
389
1235
|
#else
|
|
390
1236
|
GGML_UNUSED(D);
|
|
391
1237
|
GGML_UNUSED(A);
|
|
392
1238
|
GGML_UNUSED(B);
|
|
393
1239
|
NO_DEVICE_CODE;
|
|
394
|
-
#endif //
|
|
1240
|
+
#endif // AMD_WMMA_AVAILABLE
|
|
395
1241
|
}
|
|
396
1242
|
}
|