whispercpp 1.3.3 → 1.3.5
This diff shows the changes between publicly available package versions as they appear in a supported public registry. It is provided for informational purposes only.
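For reference, a minimal sketch of adopting the upgraded release in a Bundler-managed project. The gem name and version come from this diff; the Gemfile itself is a hypothetical example using standard Bundler conventions:

```ruby
# Gemfile — hypothetical pin to the release covered by this diff.
# "whispercpp" and "1.3.5" are from the diff header; the rest is standard Bundler.
source "https://rubygems.org"

gem "whispercpp", "~> 1.3.5"
```

Running `bundle update whispercpp` would then resolve and install 1.3.5 from the registry.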
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +79 -25
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/addon.cpp +19 -19
- data/ext/sources/examples/addon.node/index.js +7 -5
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +26 -16
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +122 -111
- data/ext/sources/examples/command/command.cpp +26 -24
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/lsp/lsp.cpp +19 -17
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +34 -24
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +4 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
- data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
- data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
- data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
- data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
- data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
- data/ext/sources/examples/talk-llama/llama-context.h +99 -36
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
- data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
- data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
- data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
- data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
- data/ext/sources/examples/talk-llama/llama-model.h +104 -12
- data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
- data/ext/sources/examples/talk-llama/llama.cpp +794 -12
- data/ext/sources/examples/talk-llama/llama.h +246 -190
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
- data/ext/sources/ggml/CMakeLists.txt +135 -79
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +21 -2
- data/ext/sources/ggml/include/ggml-cpu.h +2 -1
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +406 -23
- data/ext/sources/ggml/src/CMakeLists.txt +99 -13
- data/ext/sources/ggml/src/ggml-alloc.c +368 -161
- data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
- data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
- data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
- data/ext/sources/ggml/src/ggml-common.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
- data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
- data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
- data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
- data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
- data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
- data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
- data/ext/sources/ggml/src/ggml-impl.h +186 -15
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +111 -16
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +901 -129
- data/ext/sources/ggml/src/gguf.cpp +8 -1
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +124 -81
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +7 -5
- data/ext/sources/tests/test-vad.cpp +3 -3
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +126 -2
- data/test/test_params.rb +24 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +8 -1
- data/whispercpp.gemspec +1 -1
- metadata +439 -179
- data/ext/sources/build-xcframework.sh +0 -547
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
@@ -11,6 +11,7 @@ using namespace ggml_cuda_mma;
 
 #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
 #define MMQ_ITER_K 256
+#define MMQ_ITER_K_MXFP4_FP4 512
 #define MMQ_NWARPS 8
 
 typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
@@ -44,8 +45,15 @@ struct block_q8_1_mmq {
     };
     int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
 };
+
+struct block_fp4_mmq {
+    uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
+    int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
+};
+
 static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
 static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
+static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected block_fp4_mmq size");
 
 static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
     switch (type_x) {
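The new `block_fp4_mmq` is laid over the same tile memory as `block_q8_1_mmq`, which is why the added `static_assert` pins the two sizes together. Below is a host-side sketch (model types only, not the diff's code: `half2` is approximated as two 16-bit values and `QK8_1` is assumed to be 32, its usual ggml value) of that size relationship and of the two-scales-per-`uint32` packing the `d4` comment describes:

```cpp
#include <cstdint>
#include <cstdio>

struct half2_model { uint16_t x, y; };           // stand-in for CUDA half2

struct block_q8_1_mmq_model {
    half2_model ds4[4];                          // 4 scale/partial-sum pairs
    int8_t      qs[4*32];                        // 128 quantized values (QK8_1 == 32 assumed)
};

struct block_fp4_mmq_model {
    uint32_t d4[4];                              // 8 E8M0 scales, 2 packed per uint32
    int8_t   qs[4*32];                           // 256 FP4 values, 2 per byte
};

int main() {
    // Both are 16 + 128 = 144 bytes, so an FP4 block can alias a Q8_1 block.
    static_assert(sizeof(block_fp4_mmq_model) == sizeof(block_q8_1_mmq_model),
                  "FP4 block must be byte-compatible with the Q8_1 MMQ block");
    // Packing two 8-bit E8M0 scales s0, s1 into one uint32, as d4[0] = {s0, s1}:
    const uint8_t  s0 = 0x7F, s1 = 0x80;         // example exponent bytes
    const uint32_t packed = (uint32_t) s0 | ((uint32_t) s1 << 8);
    printf("packed scales: 0x%08x (%zu-byte blocks)\n", (unsigned) packed,
           sizeof(block_fp4_mmq_model));
}
```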
@@ -58,6 +66,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
             return MMQ_Q8_1_DS_LAYOUT_DS4;
         case GGML_TYPE_Q8_0:
             return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_MXFP4:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
         case GGML_TYPE_Q2_K:
             return MMQ_Q8_1_DS_LAYOUT_D2S6;
         case GGML_TYPE_Q3_K:
@@ -90,7 +100,7 @@ struct tile_x_sizes {
 };
 
 static int get_mmq_x_max_host(const int cc) {
-    return
+    return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 :
         GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
 #ifdef GGML_CUDA_FORCE_MMQ
             128 : 64;
@@ -100,13 +110,13 @@ static int get_mmq_x_max_host(const int cc) {
 }
 
 static constexpr __device__ int get_mmq_x_max_device() {
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     return 128;
-#else //
+#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
-#if defined(GGML_USE_HIP)
-    return
-#else // defined(GGML_USE_HIP)
+#if defined(GGML_USE_HIP)
+    return 64;
+#else // defined(GGML_USE_HIP)
 
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 #ifdef GGML_CUDA_FORCE_MMQ
@@ -115,12 +125,11 @@ static constexpr __device__ int get_mmq_x_max_device() {
     return MMQ_DP4A_MAX_BATCH_SIZE;
 #endif // GGML_CUDA_FORCE_MMQ
 #else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
-
     return 64;
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
-#endif // defined(GGML_USE_HIP)
-#endif //
+#endif // defined(GGML_USE_HIP)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 }
 
 static int get_mmq_y_host(const int cc) {
@@ -128,8 +137,16 @@ static int get_mmq_y_host(const int cc) {
         ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
 }
 
+static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
+#if defined(BLACKWELL_MMA_AVAILABLE)
+    return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
+#else
+    return MMQ_ITER_K;
+#endif // defined(BLACKWELL_MMA_AVAILABLE)
+}
+
 static constexpr __device__ int get_mmq_y_device() {
-#if defined(GGML_USE_HIP)
+#if defined(GGML_USE_HIP)
 #if defined(RDNA1)
     return 64;
 #else
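The new `get_iter_k` is the only place `MMQ_ITER_K_MXFP4_FP4` is selected: the doubled K-iteration depth applies solely to MXFP4, and only where `BLACKWELL_MMA_AVAILABLE` is defined. A compile-time sketch of that selection logic (a hypothetical standalone model, with the Blackwell toggle as a plain bool rather than a preprocessor define):

```cpp
#include <cstdio>

// Assumed values, mirroring the defines earlier in this header.
constexpr int ITER_K     = 256; // MMQ_ITER_K
constexpr int ITER_K_FP4 = 512; // MMQ_ITER_K_MXFP4_FP4

enum class Type { Q8_0, MXFP4 };

constexpr int iter_k(Type t, bool blackwell_mma) {
    return (blackwell_mma && t == Type::MXFP4) ? ITER_K_FP4 : ITER_K;
}

static_assert(iter_k(Type::MXFP4, true)  == 512, "native FP4 path consumes twice the K per iteration");
static_assert(iter_k(Type::MXFP4, false) == 256, "other architectures keep the default depth");
static_assert(iter_k(Type::Q8_0,  true)  == 256, "non-MXFP4 types are unaffected");

int main() { printf("ok\n"); }
```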
@@ -141,19 +158,28 @@ static constexpr __device__ int get_mmq_y_device() {
 #else
     return 64;
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
-#endif // defined(GGML_USE_HIP)
+#endif // defined(GGML_USE_HIP)
 }
 
-
-
-
-
-
-
-
-#define
-
-#define
+// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes.
+// The K dimension of the tiles has either,
+// 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K) or 2*MMQ_TILE_NE_K==64 (typically for TILE_X_K),
+// 32 bit elements for the quantized data (does not include scales).
+// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K.
+// The final tile size in K direction is padded to avoid shared memory bank conflicts,
+// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma.
+#define MMQ_TILE_NE_K 32
+
+#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0 + mmq_y/QI4_0, 0}
+#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1 + mmq_y/QI4_1, 0}
+#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0}
+#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0}
+#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0}
+#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K + mmq_y, 0}
+#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K + mmq_y/QI5_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K + mmq_y/QI6_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
 
 static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
     switch (type) {
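As a worked example of these macros (a host-side sketch, not the diff's code: `mmq_y = 64` is an illustrative value and `QI4_0 = 8` is the usual ggml constant, 32-bit ints per q4_0 block), the `+ mmq_y` terms are exactly the bank-conflict padding described in the comment above: one extra int per row keeps the per-row stride odd for the dp4a path (K % 2 == 1):

```cpp
#include <cstdio>

constexpr int MMQ_TILE_NE_K_ = 32;
constexpr int QI4_0_ = 8;

struct tile_x_sizes_model { int qs, dm, sc; };

constexpr tile_x_sizes_model txs_q4_0(int mmq_y) {
    return { mmq_y*MMQ_TILE_NE_K_ + mmq_y,                 // quantized data + 1 int/row of padding
             mmq_y*MMQ_TILE_NE_K_/QI4_0_ + mmq_y/QI4_0_,   // per-block scales + padding
             0 };                                          // q4_0 has no super-block scales
}

int main() {
    constexpr tile_x_sizes_model t = txs_q4_0(64);
    static_assert(t.qs == 2112 && t.dm == 264, "64 rows * 33 ints of data, 264 floats of scales");
    printf("q4_0 dp4a tile: qs=%d ints, dm=%d floats, row stride %d (odd)\n",
           t.qs, t.dm, MMQ_TILE_NE_K_ + 1);
}
```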
@@ -162,6 +188,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
         case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
         case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
+        case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
         case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
         case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
         case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
@@ -179,17 +206,20 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
     }
 }
 
-#define MMQ_MMA_TILE_X_K_Q8_0 (2*
-#define
-#define
-#define
-#define
+#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4)
+#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
+#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
+#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
 
 static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
 
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     switch (type) {
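The arithmetic behind the K % 8 == 4 rule can be checked on the host. A sketch under assumed ggml constants (`QI8_0 = 8`, `QI6_K = 32`; not part of the diff) showing that these strides come out to 76 ints each, satisfy the mma padding rule, and that the FP4 stride aliases Q8_1's:

```cpp
constexpr int NE_K = 32;                 // MMQ_TILE_NE_K
constexpr int QI8_0_ = 8, QI6_K_ = 32;   // assumed ggml values

constexpr int K_Q8_0 = 2*NE_K + 2*NE_K/QI8_0_ + 4;        // 64 + 8 + 4 = 76
constexpr int K_FP4  = 2*NE_K + 8 + 4;                    // 76: 8 packed-scale slots
constexpr int K_Q6_K = 2*NE_K + NE_K/QI6_K_ + NE_K/8 + 7; // 64 + 1 + 4 + 7 = 76

static_assert(K_Q8_0 % 8 == 4 && K_FP4 % 8 == 4 && K_Q6_K % 8 == 4, "K % 8 == 4 padding rule");
static_assert(K_FP4 == K_Q8_0, "MXFP4 reuses the Q8_0/Q8_1-shaped tile stride");

int main() {}
```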
@@ -198,6 +228,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
         case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
+        // tile sizes are the same for Q8_1 and FP4 for blackwell
+        case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
         case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
         case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
@@ -215,42 +247,77 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     }
 }
 
-
+// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
+#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_1)
+#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K
 
 static int mmq_get_granularity_host(const int mmq_x, const int cc) {
-
+    if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
+        return mmq_x >= 128 ? 32 : 16;
+    } else if (turing_mma_available(cc) && mmq_x >= 48) {
+        return 16;
+    } else {
+        return 8;
+    }
 }
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
+    return mmq_x >= 128 ? 32 : 16;
+}
+#elif defined(TURING_MMA_AVAILABLE)
 static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
     return mmq_x >= 48 ? 16 : 8;
 }
 #else
-static constexpr __device__ int mmq_get_granularity_device(const int /*
+static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) {
     return 8;
 }
-#endif //
+#endif // AMD_MFMA_AVAILABLE
+
+#if defined(GGML_USE_HIP)
+static int mmq_get_nwarps_host(const int cc, const int warp_size) {
+    return amd_mfma_available(cc) ? 8 : 256/warp_size;
+}
+#else
+static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
+    return 256/warp_size;
+}
+#endif // (GGML_USE_HIP)
+
+static constexpr __device__ int mmq_get_nwarps_device() {
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    return 8;
+#else
+    return 256/ggml_cuda_get_physical_warp_size();
+#endif // AMD_MFMA_AVAILABLE
+}
 
 // ------------------------------------------------------------
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs + 2*
+    float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-
-
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI4_0;
+    const int kqsx = txi % QI4_0;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
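The rewritten loop decomposes each warp into rows and columns instead of assuming one warp per row. A standalone sketch (not the kernel itself; constants mirror the defines above) of that lane mapping: for Q4_0, `threads_per_row = MMQ_ITER_K/(4*QR4_0) = 256/8 = 32`, so a 32-thread NVIDIA warp covers one tile row per step (`nrows == 1`) while a 64-lane AMD wave covers two (`nrows == 2`), with `nwarps` itself now coming from `mmq_get_nwarps_device()` (256/warp_size on the non-MFMA path):

```cpp
#include <cstdio>
#include <initializer_list>

int main() {
    constexpr int ITER_K = 256, QR4_0_ = 2;              // assumed ggml values
    constexpr int threads_per_row = ITER_K / (4*QR4_0_); // 32
    for (int warp_size : {32, 64}) {
        const int nrows  = warp_size / threads_per_row;
        const int nwarps = 256 / warp_size;              // non-MFMA device path
        // lane -> (row offset within warp, column index txi), as in the rewritten loop
        for (int lane : {0, 31, warp_size - 1}) {
            const int txi = warp_size > threads_per_row ? lane % threads_per_row : lane;
            const int row = nrows == 1 ? 0 : lane / threads_per_row;
            printf("warp_size=%2d nwarps=%d: lane %2d -> row %d, txi %2d\n",
                   warp_size, nwarps, lane, row, txi);
        }
    }
}
```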
@@ -259,20 +326,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b2(bxi->qs, kqsx);
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
 #else
-        x_qs[i*(
-#endif //
+        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
-
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y *
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 
         if (need_check) {
            i = min(i, i_max);
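The `__vsubss4` lines kept as context above do SIMD-within-a-register nibble unpacking. A scalar host-side model of one of them (a sketch, not the CUDA intrinsic): each 32-bit word holds four q4_0 bytes; masking with `0x0F0F0F0F` keeps the low nibble of every byte and subtracting 0x08 per byte lane recenters the unsigned nibble range [0,15] to the signed quant range [-8,7]:

```cpp
#include <cstdint>
#include <cstdio>

// One lane-by-lane emulation of __vsubss4(qs & 0x0F0F0F0F, 0x08080808).
uint32_t unpack_low_nibbles_minus8(uint32_t qs) {
    uint32_t out = 0;
    for (int b = 0; b < 4; ++b) {
        const int v = (int)((qs >> (8*b)) & 0x0F) - 8; // low nibble of lane b, recentered
        out |= (uint32_t)(uint8_t)v << (8*b);          // repack as a signed byte lane
    }
    return out;
}

int main() {
    // Byte lanes 0x01, 0x00, 0x08, 0x0F -> -7, -8, 0, 7 -> 0xF9, 0xF8, 0x00, 0x07.
    printf("0x%08X\n", (unsigned) unpack_low_nibbles_minus8(0x0F080001u)); // 0x0700F8F9
}
```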
@@ -280,17 +348,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
 
-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
 #else
-        x_df[i*(
-#endif //
+        x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     const int * x_qs = (const int *) x;
@@ -299,7 +369,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
     const half2 * y_ds = (const half2 *) y;
 
     // #pragma unroll
-    for (int k01 = 0; k01 <
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -307,7 +377,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
         const int j = j0 + threadIdx.y;
 
 #pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 +=
+        for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
 
             const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
@@ -320,32 +390,37 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
                 u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
             }
 
-            sum[j0/nwarps*mmq_y/
-                (&x_qs[i*(
-                 x_df[i*(
+            sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+                (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u,
+                 x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
     }
 }
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    half2 * x_dm = (half2 *) (x_qs + 2*
+    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-
-
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI4_1;
+    const int kqsx = txi % QI4_1;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
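The new `sum[j0/nwarps*mmq_y/warp_size + i0/warp_size]` index gives each thread a private, row-major grid of (mmq_x/nwarps) x (mmq_y/warp_size) partial sums. A host-side sketch checking the mapping is a bijection (the tile sizes are illustrative, not from the diff):

```cpp
#include <cassert>
#include <cstdio>

int main() {
    constexpr int mmq_x = 64, mmq_y = 64, nwarps = 8, warp_size = 32; // illustrative sizes
    bool seen[(mmq_x/nwarps)*(mmq_y/warp_size)] = {};
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {        // j-tiles handled by this thread
        for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { // i-tiles handled by this thread
            const int idx = j0/nwarps*(mmq_y/warp_size) + i0/warp_size;
            assert(!seen[idx]); // each (j0, i0) pair owns a distinct accumulator slot
            seen[idx] = true;
        }
    }
    printf("16 loop iterations -> 16 distinct sum[] slots\n");
}
```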
@@ -354,20 +429,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b4(bxi->qs, kqsx);
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
-        x_qs[i*(
-#endif //
+        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
-
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y *
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -375,17 +451,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
 
-#
-        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
 #else
-        x_dm[i*(
-#endif //
+        x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     const int * x_qs = (const int *) x;
@@ -394,7 +472,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
     const half2 * y_ds = (const half2 *) y;
 
     // #pragma unroll
-    for (int k01 = 0; k01 <
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -402,7 +480,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
         const int j = j0 + threadIdx.y;
 
 #pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 +=
+        for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
 
             const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
@@ -415,32 +493,37 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
                 u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
             }
 
-            sum[j0/nwarps*mmq_y/
-                (&x_qs[i*(
-                 x_dm[i*(
+            sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
+                (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u,
+                 x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
            }
         }
     }
 }
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs +
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-
-
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI5_0;
+    const int kqsx = txi % QI5_0;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
@@ -449,7 +532,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
 
         const int ql = get_int_b2(bxi->qs, kqsx);
-        const int qh = get_int_b2(bxi->qh, 0) >> (4 *
+        const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx);
 
         int qs0 = (ql >> 0) & 0x0F0F0F0F;
         qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
@@ -465,21 +548,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
         qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
 #else
-        x_qs[i*(2*
-        x_qs[i*(2*
-#endif //
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
-
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y *
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -487,32 +571,37 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
 
-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
 #else
-        x_df[i*(
-#endif //
+        x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    half2 * x_dm = (half2 *) (x_qs + 2*
+    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-
-
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI5_1;
+    const int kqsx = txi % QI5_1;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
@@ -521,7 +610,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
 
         const int ql = get_int_b4(bxi->qs, kqsx);
-        const int qh = get_int_b4(bxi->qh, 0) >> (4 *
+        const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx);
 
         int qs0 = (ql >> 0) & 0x0F0F0F0F;
         qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
@@ -535,21 +624,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
         qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
 #else
-        x_qs[i*(2*
-        x_qs[i*(2*
-#endif //
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
-
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y *
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -557,32 +647,38 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
 
-#
-        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
 #else
-        x_dm[i*(
-#endif //
+        x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_tile + 2*
+    float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-
-
+    // MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp
+    constexpr int threads_per_row = 32;
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI8_0;
+    const int kqsx = txi % QI8_0;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
@@ -590,21 +686,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
 
-#
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 +
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
 #else
-        x_qs[i*(2*
-        x_qs[i*(2*
-#endif //
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
-
+    constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y *
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -612,17 +709,128 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
 
-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
 #else
-        x_df[i*(2*
-#endif //
+        x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
-template <int
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    int * x_qs = (int *) x_tile;
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
+    int * x_qs = (int *) x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI_MXFP4;
+    const int kqsx = txi % QI_MXFP4;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx;
+
+        const int aux_q4 = get_int_b1(bxi->qs, kqsx);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
+        const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
+#else
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    }
+
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
+#else
+        x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    }
+}
+
+template <int mmq_y, bool need_check>
+static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restrict__ x,
+                                                            int * __restrict__ x_tile,
+                                                            const int kbx0,
+                                                            const int i_max,
+                                                            const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+    int * x_qs = (int *) x_tile;
+    uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
+
+    const int txi = threadIdx.x;
+
+    constexpr int iter_k = get_iter_k(GGML_TYPE_MXFP4);
+
+    constexpr int threads_per_row = iter_k / QK_MXFP4; // each thread processes 1 block
+    constexpr int rows_per_warp = warp_size / threads_per_row;
+    const int kbx = txi % threads_per_row;
+    const int row_in_warp = txi / threads_per_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
+        int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
+
+        if constexpr (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
+
+        // quantize_mxfp4_mmq permutes nibbles to match the quantized format
+        const int k0 = kbx * 4;
+        memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
+
+        // Load E8M0 scales: pack 2 consecutive scales into one uint32
+        if (kbx % 2 == 0) {
+            uint32_t e = bxi->e;
+            e |= ((bxi + 1)->e << 8);
+            x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
+        }
+    }
+}
+
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
     const int * x_qs = (const int *) x;
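In `load_tiles_mxfp4_fp4`, only the even k-blocks write scales: each folds its neighbour's 8-bit E8M0 exponent into the upper byte, so the `x_sc` tile stores one `uint32` per two blocks, matching the `block_fp4_mmq::d4` layout. A host-side sketch of that pairing (example exponent bytes are made up; the `*0.5f` in the non-FP4 decode above appears to compensate for `kvalues_mxfp4` holding doubled, integer-valued FP4 levels):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const uint8_t e[8] = {127, 128, 126, 129, 125, 130, 124, 131}; // example E8M0 bytes
    uint32_t x_sc[4];
    for (int kbx = 0; kbx < 8; ++kbx) {
        if (kbx % 2 == 0) {
            uint32_t packed = e[kbx];
            packed |= (uint32_t) e[kbx + 1] << 8; // neighbour's scale into bits 8..15
            x_sc[kbx / 2] = packed;
        }
    }
    printf("x_sc[0] = 0x%08x\n", (unsigned) x_sc[0]); // 0x0000807f
}
```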
@@ -631,7 +839,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
     const float * y_df = (const float *) y;
 
     // #pragma unroll
-    for (int k01 = 0; k01 <
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -639,21 +847,77 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
         const int j = j0 + threadIdx.y;
 
 #pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 +=
+        for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
 
-            sum[j0/nwarps*mmq_y/
-                (&x_qs[i*(2*
-                 x_df[i*(2*
+            sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
+                (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K],
+                 x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]);
         }
     }
 }
 
-template <int mmq_x, int mmq_y,
+template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16, 8, int, input_layout> tile_A;
+    typedef tile<16, 8, int, input_layout> tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = granularity;
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+    const int * x_qs = (const int *) x;
+    const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
+    const int * y_qs = (const int *) y + 4;
+    const float * y_df = (const float *) y;
+    const half2 * y_ds = (const half2 *) y;
 
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
+        const int k0 = k00 + k01;
+
+        tile_A A[ntx];
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+        }
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+            tile_B B;
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+            float dB;
+            const int j = j0 + tile_C::get_j(0);
+            if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+                dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+            } else {
+                dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C C;
+                mma(C, A[n], B);
+
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    const int i = i0 + n*tile_A::I + tile_C::get_i(l);
+                    const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB;
+                }
+            }
+        }
+    }
+#else
     typedef tile<16, 8, int> tile_A;
     typedef tile< 8, 8, int> tile_B;
     typedef tile<16, 8, int> tile_C;
@@ -662,23 +926,23 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     constexpr int rows_per_warp = 2 * granularity;
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
 
     const int * x_qs = (const int *) x;
-    const float * x_df = (const float *) x_qs + 2*
+    const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
     const int * y_qs = (const int *) y + 4;
     const float * y_df = (const float *) y;
     const half2 * y_ds = (const half2 *) y;
 
-    tile_A A[ntx][
-    float dA[ntx][tile_C::ne/2][
+    tile_A A[ntx][MMQ_TILE_NE_K/QI8_0];
+    float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0];
 
     const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int k01 = 0; k01 <
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
             const int k0 = k00 + k01;
 
             load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
@@ -689,7 +953,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
             const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
 
 #pragma unroll
-            for (int k01 = 0; k01 <
+            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
                 const int k0 = k00 + k01;
 
                 dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
@@ -700,7 +964,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
-        for (int k01 = 0; k01 <
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
             tile_B B;
             float dB[tile_C::ne/2];
 
@@ -729,11 +993,86 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
             }
         }
     }
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 }
 
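In both new branches of vec_dot_q8_0_q8_1_mma the tensor-core MMA runs entirely on int8 quants, and the float scales are applied to the int32 fragment afterwards (`C.x[l]*dA*dB`). That relies on factoring the scales out of the sum; a tiny C++ check of the identity with illustrative values:

```cpp
#include <cstdio>

int main() {
    const int   qx[4] = { 3, -7, 12, 1 };  // int8 quants of x
    const int   qy[4] = { 5,  2, -1, 9 };  // int8 quants of y
    const float dA = 0.125f, dB = 0.0625f; // per-block scales

    float direct = 0.0f; // dequantize first, then multiply
    int   sumi   = 0;    // integer dot product, scale once at the end
    for (int k = 0; k < 4; ++k) {
        direct += (dA * qx[k]) * (dB * qy[k]);
        sumi   += qx[k] * qy[k];
    }
    // Both accumulations agree, which is what lets the MMA stay in int8/int32.
    std::printf("direct = %f, scaled integer = %f\n", direct, dA * dB * sumi);
    return 0;
}
```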
-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
+static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
+                                                               const int * __restrict__ y,
+                                                               float * __restrict__ sum,
+                                                               const int k00) {
+    typedef tile<16, 8, int> tile_A;
+    typedef tile<8, 8, int> tile_B;
+    typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
+
+    // Match layout from load_tiles_mxfp4_fp4
+    const int * x_qs = (const int *) x;
+    const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
+    const int * y_qs = (const int *) y + 4;
+    const uint32_t * y_sc = (const uint32_t *) y;
+
+    // tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
+    tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
+    uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
+
+    // Block scale
+    // Each thread has to point to a 4 byte scale value
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
+
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+#pragma unroll
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
+            const int k0 = k00 + k01;
+
+            load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
+                          MMQ_MMA_TILE_X_K_FP4);
+
+            // based on the block-scaling document, 2 threads in each quad need to supply the scale value
+            const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
+            scaleA[n][k01 / (2 * QI_MXFP4)] =
+                *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
+        }
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
+#pragma unroll
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
+            tile_B B;
+            uint32_t scaleB; // 2xN scales
+
+            load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
+
+            scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C C;
+
+                mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
+                }
+            }
+        }
+    }
+}
+
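vec_dot_mxfp4_mxfp4_mma feeds the FP4 nibbles and the packed E8M0 scales straight into a block-scaled tensor-core MMA (mma_block_scaled), so nothing is dequantized in software. For reference, a host-side C++ sketch of what one MXFP4 value represents: a 4-bit E2M1 code times the block's E8M0 power-of-two scale. The value table follows the OCP MX spec and is an assumption, not taken from this diff:

```cpp
#include <cstdint>
#include <cmath>
#include <cstdio>

// E2M1 (FP4) magnitudes; the high nibble bit is the sign.
static const float kFp4Values[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };

// Decode one 4-bit code under a block-wide E8M0 exponent e (scale = 2^(e-127)).
static float decode_mxfp4(uint8_t code, uint8_t e) {
    const float mag  = kFp4Values[code & 0x7];
    const float sign = (code & 0x8) ? -1.0f : 1.0f;
    return sign * mag * std::ldexp(1.0f, (int) e - 127);
}

int main() {
    const uint8_t byte = 0xA3; // two packed nibbles: low = 0x3, high = 0xA
    const uint8_t e    = 126;  // block scale 0.5
    std::printf("low = %f, high = %f\n",
                decode_mxfp4(byte & 0x0F, e), decode_mxfp4(byte >> 4, e));
    return 0;
}
```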
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
     const int * x_qs = (const int *) x;
@@ -742,7 +1081,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
     const half2 * y_ds = (const half2 *) y;
 
     // #pragma unroll
-    for (int k01 = 0; k01 <
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -750,45 +1089,96 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 +=
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
-                sum[j0/nwarps*mmq_y/
-                    (&x_qs[i*(2*
-                     x_dm[i*(
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
+                    (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                     x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
     }
 }
 
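The q8_1 variants carry a second per-block value besides the scale: the x tile stores (d, m) pairs and the y tile stores (d, s), where, under ggml's usual q8_1 convention (an assumption here), s is the scale times the sum of the block's int8 quants. That is what lets the dot product split into an integer part plus a cheap correction, visible below as `dmA.x*dsB.x*C.x[l] + dmA.y*dsB.y`. A C++ check of the identity:

```cpp
#include <cstdio>

// Identity used by vec_dot_q8_1_q8_1_*: for x_k = d_x*qx_k + m and
// y_k = d_y*qy_k, the dot product splits into an integer part and a
// correction that only needs the precomputed s = d_y * sum(qy).
int main() {
    const int   qx[4] = { 1, -2, 3, 4 };
    const int   qy[4] = { 7,  0, -5, 2 };
    const float d_x = 0.25f, m = 1.5f, d_y = 0.5f;

    float direct = 0.0f;
    int   sumi   = 0;
    int   sum_qy = 0;
    for (int k = 0; k < 4; ++k) {
        direct += (d_x * qx[k] + m) * (d_y * qy[k]);
        sumi   += qx[k] * qy[k];
        sum_qy += qy[k];
    }
    const float s = d_y * sum_qy; // stored alongside d_y in the y tile
    std::printf("direct = %f, split = %f\n", direct, d_x * d_y * sumi + m * s);
    return 0;
}
```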
-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16, 8, int, input_layout> tile_A;
+    typedef tile<16, 8, int, input_layout> tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
 
-
-
-
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = granularity;
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+    const int * x_qs = (const int *) x;
+    const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
+    const int * y_qs = (const int *) y + 4;
+    const half2 * y_dm = (const half2 *) y;
+
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
+        const int k0 = k00 + k01;
+
+        tile_A A[ntx];
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+        }
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+            tile_B B;
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+            const int j = j0 + tile_C::get_j(0);
+            const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C C;
+                mma(C, A[n], B);
+
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    const int i = i0 + n*tile_A::I + tile_C::get_i(l);
+                    float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l];
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y;
+                }
+            }
+        }
+    }
+#else
+    typedef tile<16, 8, int> tile_A;
+    typedef tile< 8, 8, int> tile_B;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
 
     const int * x_qs = (const int *) x;
-    const half2 * x_dm = (const half2 *) x_qs + 2*
+    const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
     const int * y_qs = (const int *) y + 4;
     const half2 * y_dm = (const half2 *) y;
 
-    tile_A A[ntx][
-    float2 dmA[ntx][tile_C::ne/2][
+    tile_A A[ntx][MMQ_TILE_NE_K/QI8_1];
+    float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1];
 
     const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int k01 = 0; k01 <
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
             const int k0 = k00 + k01;
 
             load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
@@ -799,7 +1189,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
             const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
 
 #pragma unroll
-            for (int k01 = 0; k01 <
+            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
                 const int k0 = k00 + k01;
 
                 dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
@@ -810,7 +1200,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
-        for (int k01 = 0; k01 <
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
             tile_B B;
             float2 dsB[tile_C::ne/2];
 
@@ -836,11 +1226,15 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
             }
         }
     }
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 }
 
-
+// Used for Q3_K, IQ2_S, and IQ2_XS
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
     const int * x_qs = (const int *) x;
@@ -849,7 +1243,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
     const float * y_df = (const float *) y;
 
     // #pragma unroll
-    for (int k01 = 0; k01 <
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -857,23 +1251,123 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 +=
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
-                sum[j0/nwarps*mmq_y/
-                    &x_qs[i*(2*
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
+                    &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0],
                     &y_qs[j*MMQ_TILE_Y_K + k01],
-                    &x_df[i*(2*
+                    &x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
                     y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
     }
 }
 
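As the comment on the next function notes, the q8_0_16 variants serve Q3_K, IQ2_S, and IQ2_XS: the quants are q8_0-style int8 values, but a float scale covers only 16 of them instead of a full 32-value block. A host-side C++ sketch of a dot product under that finer scale granularity (all values illustrative):

```cpp
#include <cstdio>

// One float scale per 16 quants on the x side (half the usual q8_0 block);
// y keeps one scale per 32 values as in q8_1.
int main() {
    const int qx[32] = {
        1,2,3,4,5,6,7,8,1,2,3,4,5,6,7,8,
        -1,-2,-3,-4,-5,-6,-7,-8,-1,-2,-3,-4,-5,-6,-7,-8 };
    const int qy[32] = {
        1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,
        2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2 };
    const float d_x[2] = { 0.5f, 0.25f }; // one scale per 16-value subblock
    const float d_y    = 0.125f;

    float dot = 0.0f;
    for (int sb = 0; sb < 2; ++sb) {      // subblocks of 16
        int sumi = 0;
        for (int k = 0; k < 16; ++k) {
            sumi += qx[16*sb + k] * qy[16*sb + k];
        }
        dot += d_x[sb] * d_y * (float) sumi;
    }
    std::printf("dot = %f\n", dot);
    return 0;
}
```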
-
+// Used for Q3_K, IQ2_S, and IQ2_XS:
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
-#
+#if defined(AMD_MFMA_AVAILABLE)
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16, 8, int, input_layout> tile_A;
+    typedef tile<16, 8, int, input_layout> tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+    typedef tile<64, 2, int, input_layout> tile_load;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = granularity;
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+    const int * x_qs = (const int *) x;
+    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+    const int * y_qs = (const int *) y + 4;
+    const float * y_df = (const float *) y;
+
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+        const int k0 = k00 + k01;
+
+        tile_A A[ntx];
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+        }
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+            tile_B B[1];
+            load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+            const int j = j0 + tile_C::get_j(0);
+            const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C C;
+                mma(C, A[n], B[0]);
+
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
+                }
+            }
+        }
+    }
+#elif defined(AMD_WMMA_AVAILABLE) // wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16, 4, int, input_layout> tile_A;
+    typedef tile<16, 4, int, input_layout> tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = granularity;
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+    const int * x_qs = (const int *) x;
+    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+    const int * y_qs = (const int *) y + 4;
+    const float * y_df = (const float *) y;
+
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+        const int k0 = k00 + k01;
+
+        tile_A A[ntx];
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+        }
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+            tile_B B;
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+            const int j = j0 + tile_C::get_j(0);
+            const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C C;
+                mma(C, A[n], B);
+
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
+                }
+            }
+        }
+    }
+#elif defined(TURING_MMA_AVAILABLE)
 
     typedef tile<16, 4, int> tile_A;
     typedef tile<16, 8, int> tile_A_8;
@@ -884,10 +1378,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     constexpr int rows_per_warp = 2 * granularity;
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
 
     const int * x_qs = (const int *) x;
-    const float * x_df = (const float *) x_qs +
+    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
     const int * y_qs = (const int *) y + 4;
     const float * y_df = (const float *) y;
 
@@ -899,7 +1393,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int k01 = 0; k01 <
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
             const int k0 = k00 + k01;
 
             load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
@@ -910,7 +1404,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
             const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
 
 #pragma unroll
-            for (int k01 = 0; k01 <
+            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
                 const int k0 = k00 + k01;
 
                 dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4];
@@ -921,7 +1415,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
-        for (int k01 = 0; k01 <
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
             tile_B B[2];
             float dB[tile_C::ne/2];
 
@@ -950,28 +1444,31 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
         }
     }
 #else
-
+    GGML_UNUSED_VARS(x, y, sum, k00);
     NO_DEVICE_CODE;
-#endif //
+#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE
 }
 
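The load_tiles_* rewrites that follow replace the fixed one-row-per-warp layout with a threads_per_row split: MMQ_ITER_K/(4*QR) lanes cooperate on one tile row, and the remaining lanes of the warp cover additional rows per step. A host-side C++ sketch of the lane-to-row mapping, with constants chosen only for the printout:

```cpp
#include <cstdio>

int main() {
    // Illustrative constants: a 64-wide "warp" split so that 16 lanes
    // cooperate on one row, giving nrows = 4 rows per warp per step.
    const int warp_size       = 64;
    const int threads_per_row = 16; // MMQ_ITER_K / (4 * QR) in the diff
    const int nrows           = warp_size / threads_per_row;

    for (int tx = 0; tx < warp_size; ++tx) {
        const int row_in_warp = tx / threads_per_row; // which row this lane loads
        const int kqsx        = tx % threads_per_row; // position within the row
        if (kqsx == 0) {
            std::printf("lanes %2d..%2d -> row %d (nrows = %d)\n",
                        tx, tx + threads_per_row - 1, row_in_warp, nrows);
        }
    }
    // The i0 loop then advances by nrows*nwarps rows per iteration:
    // int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
    return 0;
}
```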
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    half2 * x_dm = (half2 *) (x_qs + 2*
+    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
+    constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps
-        int i = i0 + threadIdx.y*
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -987,11 +1484,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
 #else
-        x_qs[i*(2*
-#endif //
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
     const int sc_m = bxi->scales[kqsx];
@@ -1002,17 +1499,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
 #endif // FAST_FP16_AVAILABLE
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
 #else
-        x_dm[i*(
-#endif //
+        x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
     const int * x_qs = (const int *) x;
@@ -1029,7 +1528,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
     }
 
 #pragma unroll
-    for (int k01 = 0; k01 <
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -1037,13 +1536,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 +=
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
                 constexpr int ns = 2;
-                sum[j0/nwarps*mmq_y/
-                    &x_qs[i*(2*
-                    &x_dm[i*(
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+                    &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                    &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
                     &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
             }
         }
@@ -1052,7 +1551,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
     // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop.
     // As a workaround 2 separate loops are used instead.
 #pragma unroll
-    for (int k01 =
+    for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -1060,23 +1559,158 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 +=
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
                 constexpr int ns = 1;
-                sum[j0/nwarps*mmq_y/
-                    &x_qs[i*(2*
-                    &x_dm[i*(
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+                    &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                    &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
                     &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
             }
         }
     }
 }
 
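The two k01 loops above exist, per the in-source comment, because a runtime-conditional `ns` inside one loop can stop `#pragma unroll` from firing. The general pattern is to split the iteration range so the condition becomes a compile-time constant in each copy; a C++ sketch of that workaround (the `process` helper is hypothetical):

```cpp
#include <cstdio>

// Instead of one loop with a runtime "ns = (k01 < half) ? 2 : 1" that can
// block unrolling, the range is split so ns is a template constant per copy.
template <int ns>
static float process(int k01) {
    return (float) (ns * (k01 + 1)); // stand-in for the per-step dot product
}

int main() {
    const int total = 32, half = 16, step = 8;
    float sum = 0.0f;
    for (int k01 = 0; k01 < half; k01 += step) {     // first half: ns == 2
        sum += process<2>(k01);
    }
    for (int k01 = half; k01 < total; k01 += step) { // second half: ns == 1
        sum += process<1>(k01);
    }
    std::printf("sum = %f\n", sum);
    return 0;
}
```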
-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
-#
+#if defined(AMD_MFMA_AVAILABLE)
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16, 8, int, input_layout> tile_A;
+    typedef tile<16, 8, int, input_layout> tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+    typedef tile<64, 2, int, input_layout> tile_load;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = granularity;
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+    const int * x_qs = (const int *) x;
+    const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
+    const int * y_qs = (const int *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+        const int k0 = k00 + k01;
+
+        tile_A A[ntx];
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+        }
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+            tile_B B[1];
+            load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+            const int j = j0 + tile_C::get_j(0);
+            const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2;
+            const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
+                                                          : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
+                                                                         : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
+
+            tile_C Cm;
+            if (k01 >= MMQ_TILE_NE_K * 3/4) {
+                tile_A A1;
+                A1.x[0] = 0x01010101;
+                A1.x[1] = 0x01010101;
+                mma(Cm, A1, B[0]);
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C Cd;
+                mma(Cd, A[n], B[0]);
+
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+                    const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
+                    float tmp = Cd.x[l]*dm.x;
+                    if (k01 >= MMQ_TILE_NE_K * 3/4) {
+                        tmp -= Cm.x[l]*dm.y;
+                    }
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
+                }
+            }
+        }
+    }
+#elif defined(AMD_WMMA_AVAILABLE) // wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16, 4, int, input_layout> tile_A;
+    typedef tile<16, 4, int, input_layout> tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = granularity;
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+    const int * x_qs = (const int *) x;
+    const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
+    const int * y_qs = (const int *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+        const int k0 = k00 + k01;
+
+        tile_A A[ntx];
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+        }
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+            tile_B B;
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+            const int j = j0 + tile_C::get_j(0);
+            const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y;
+            const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
+                                                          : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
+                                                                         : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
+
+            tile_C Cm;
+            if (k01 >= MMQ_TILE_NE_K * 3/4) {
+                tile_A A1;
+#pragma unroll
+                for (int l = 0; l < tile_A::ne; ++l) {
+                    A1.x[l] = 0x01010101;
+                }
+                mma(Cm, A1, B);
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C Cd;
+                mma(Cd, A[n], B);
+
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+                    const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
+                    float tmp = Cd.x[l]*dm.x;
+                    if (k01 >= MMQ_TILE_NE_K * 3/4) {
+                        tmp -= Cm.x[l]*dm.y;
+                    }
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
+                }
+            }
+        }
+    }
+#elif defined(TURING_MMA_AVAILABLE)
 
     typedef tile<16, 4, int> tile_A;
     typedef tile<16, 8, int> tile_A_8;
@@ -1087,10 +1721,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     constexpr int rows_per_warp = 2 * granularity;
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
 
     const int * x_qs = (const int *) x;
-    const half2 * x_dm = (const half2 *) x_qs +
+    const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
     const int * y_qs = (const int *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
@@ -1103,7 +1737,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int k01 = 0; k01 <
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
            const int k0 = k00 + k01;
 
            load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
@@ -1117,7 +1751,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
             const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
 
 #pragma unroll
-            for (int k01 = 0; k01 <
+            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) {
                 const int k0 = k00 + k01;
 
                 const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);
@@ -1140,7 +1774,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     }
 
 #pragma unroll
-    for (int k01 = 0; k01 <
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
         tile_B B[2];
 
         // Here load_generic is faster than load_ldmatrix.
@@ -1148,7 +1782,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
         load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
 
         tile_C Cm[2];
-        if (k01 >=
+        if (k01 >= MMQ_TILE_NE_K * 3/4) {
             tile_A A1;
             A1.x[0] = 0x01010101;
             A1.x[1] = 0x01010101;
@@ -1166,16 +1800,16 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 #pragma unroll
         for (int l = 0; l < tile_C::ne; ++l) {
             float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
-            if (k01 >=
+            if (k01 >= MMQ_TILE_NE_K * 3/4) {
                 tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
             }
-            sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 <
+            sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? dB[l%2].x : dB[l%2].y);
         }
     }
 }
 
 #pragma unroll
-    for (int k01 = 0; k01 <
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) {
         float2 sB[tile_C::ne/2];
 
 #pragma unroll
@@ -1196,29 +1830,33 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
         }
     }
 #else
-
+    GGML_UNUSED_VARS(x, y, sum, k00);
     NO_DEVICE_CODE;
-#endif //
+#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE
 }
 
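All three branches of vec_dot_q2_K_q8_1_mma use the same trick for the Q2_K minimums: when k01 reaches the last quarter of the tile, an A fragment filled with 0x01010101 (int8 lanes of 1) is multiplied against B, so the integer MMA produces column sums of the y quants, exactly the factor the minimum m must multiply. A C++ check of that identity with illustrative values:

```cpp
#include <cstdio>

// Cd is the regular quant dot product; Cm comes from an all-ones row and
// therefore equals sum(qy) per column, the term needed for the "- m" part
// of the Q2_K dequantization x_k = d*qx_k - m.
int main() {
    const int qy[4] = { 5, -3, 2, 7 };
    const int qx[4] = { 1,  2, 0, 3 };
    const float d = 0.5f, m = 0.25f, d_y = 1.0f;

    int sumi = 0, col_sum = 0;
    for (int k = 0; k < 4; ++k) {
        sumi    += qx[k] * qy[k]; // Cd
        col_sum += 1 * qy[k];     // Cm: ones row picks out sum(qy)
    }
    float direct = 0.0f;
    for (int k = 0; k < 4; ++k) {
        direct += (d * qx[k] - m) * (d_y * qy[k]);
    }
    std::printf("direct = %f, via ones-trick = %f\n",
                direct, d_y * (d * sumi - m * col_sum));
    return 0;
}
```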
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs +
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
     int * x_sc = (int *) (x_df + txs.dm);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps
-        int i = i0 + threadIdx.y
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -1238,17 +1876,18 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
 #else
-        x_qs[i*(2*
-#endif //
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         }
     }
 
+    constexpr int rows_per_warp = warp_size / 4;
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*
-        int i = i0 + threadIdx.y*
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+        int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4;
 
         if (need_check) {
             i = min(i, i_max);
@@ -1256,7 +1895,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
 
-        const int ksc = threadIdx.x %
+        const int ksc = threadIdx.x % 4;
 
         const int ksc_low = ksc % (QI3_K/8);
         const int shift_low = 4 * (ksc / (QI3_K/8));
@@ -1268,23 +1907,23 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         const int8_t * sc8 = (const int8_t *) &sc;
         const float d = bxi->d;
 
 #pragma unroll
         for (int l = 0; l < int(sizeof(int)); ++l) {
-            x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*
+            x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l];
         }
 #else
-        x_sc[i*(
-#endif //
+        x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
-#
+#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE))
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*
-        int i = (i0 + threadIdx.y*
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
+        int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
@@ -1294,12 +1933,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         x_df[i] = bxi->d;
     }
-#endif //
+#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE))
 }
 
-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
     const int * x_qs = (const int *) x;
@@ -1309,7 +1950,7 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
     const float * y_df = (const float *) y;
 
     // #pragma unroll
-    for (int k01 = 0; k01 <
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -1317,13 +1958,13 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 +=
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
-                const int8_t * scales = ((const int8_t *) (x_sc + i*(
+                const int8_t * scales = ((const int8_t *) (x_sc + i*(MMQ_TILE_NE_K/8) + i/8)) + k0/4;
 
-                sum[j0/nwarps*mmq_y/
-                    &x_qs[i*(2*
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q3_K_q8_1_impl_mmq(
+                    &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
                     x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
@@ -1340,72 +1981,85 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co
            ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits
 }
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    half2 * x_dm = (half2 *) (x_qs + 2*
+    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
     int * x_sc = (int *) (x_dm + txs.dm);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
         }
 
         const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
-        const int qs0 = get_int_b4(bxi->qs,
+        const int qs0 = get_int_b4(bxi->qs, txi);
 
-#
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
-        x_qs[i*(
-#endif //
+        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
-#
-
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    constexpr int rows_per_warp = warp_size / 2;
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*
-
-
-
-
-
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        // Need if on AMD instead of % because warp_size == 64
+        // This causes double work and throughput loss (MI300X)
+        // H100 loses about 100 t/s with 'if' condition over '%'
+        int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
+        if (i < mmq_y) {
+#else
+        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
+        {
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+            if (need_check) {
+                i = min(i, i_max);
+            }
 
-
+            const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
 
-
-
+            const int * scales = (const int *) bxi->scales;
+            const int ksc = threadIdx.x % 2;
 
-
-
+            const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
+            const int m32  = unpack_scales_q45_K(scales, ksc + 2);
 
-
-
+            const uint8_t * sc8 = (const uint8_t *) &sc32;
+            const uint8_t * m8  = (const uint8_t *) &m32;
 
-
+            const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
 
-#pragma unroll
-
-
+#pragma unroll
+            for (int l = 0; l < sizeof(int); ++l) {
+                x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+            }
         }
     }
-
 #else
-
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*
-        int i = (i0 + threadIdx.y*
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
+        int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
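The scale staging just above multiplies the Q4_K super-block pair dm = (d, dmin) by (1, -1) and then by each subblock's (sc, m), so the tile directly stores (d·sc, -dmin·m) and the later dot product becomes a plain multiply-add. A C++ check of that precomputation in float (values illustrative; Q4_K dequantizes as d·sc·q - dmin·m):

```cpp
#include <cstdio>

struct float2s { float x, y; };

int main() {
    const float d = 0.5f, dmin = 0.125f; // block super-scales
    const int   sc = 11, m = 3;          // 6-bit subblock scale and min

    const float2s dm     = { d * 1.0f, dmin * -1.0f };   // dm * (1, -1)
    const float2s staged = { dm.x * sc, dm.y * m };      // what x_dm holds

    const int q = 9; // one 4-bit quant
    const float direct  = d * sc * q - dmin * m;
    const float viaTile = staged.x * q + staged.y;
    std::printf("direct = %f, staged = %f\n", direct, viaTile);
    return 0;
}
```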
@@ -1415,30 +2069,32 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         x_dm[i] = bxi->dm;
     }
-
+    constexpr int rows_per_warp = warp_size / 4;
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps
-        int i = (i0 + threadIdx.y
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
         }
 
-        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (
+        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8);
 
         const int * scales = (const int *) bxi->scales;
 
-        const int ksc = threadIdx.x % (
+        const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
         const int scales8 = unpack_scales_q45_K(scales, ksc);
 
-        x_sc[i*(
+        x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
     }
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 }
 
-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
     const int * x_qs = (const int *) x;
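unpack_scales_q45_K, referenced throughout these hunks, reads the 6-bit scales and mins that Q4_K/Q5_K pack into 12 bytes per super-block; the `& 0x30303030 // upper 2 bits` line above shows the high-bits half of that split. The exact bit layout lives in the ggml sources, so the C++ sketch below only demonstrates the general low-4-bits plus high-2-bits reassembly, not the real field placement:

```cpp
#include <cstdint>
#include <cstdio>

// Reassemble a 6-bit value from a 4-bit low field and a 2-bit high field.
static uint8_t pack6(uint8_t low4, uint8_t high2) {
    return (uint8_t) ((high2 << 4) | low4);
}

int main() {
    const uint8_t value = 45;                  // any 6-bit scale, 0..63
    const uint8_t low4  = value & 0x0F;        // stored in one nibble field
    const uint8_t high2 = (value >> 4) & 0x03; // stored in a 2-bit field
    std::printf("reassembled = %u\n", (unsigned) pack6(low4, high2));
    return 0;
}
```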
@@ -1448,7 +2104,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
     const half2 * y_ds = (const half2 *) y;
 
     // #pragma unroll
-    for (int k01 = 0; k01 <
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -1456,97 +2112,110 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
|
|
|
1456
2112
|
const int j = j0 + threadIdx.y;
|
|
1457
2113
|
|
|
1458
2114
|
#pragma unroll
|
|
1459
|
-
for (int i0 = 0; i0 < mmq_y; i0 +=
|
|
2115
|
+
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
|
1460
2116
|
const int i = i0 + threadIdx.x;
|
|
1461
2117
|
|
|
1462
|
-
const uint8_t * sc = (const uint8_t *) &x_sc[i * (
|
|
2118
|
+
const uint8_t * sc = (const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/32] + 2*(k01/16);
|
|
1463
2119
|
|
|
1464
|
-
sum[j0/nwarps*mmq_y/
|
|
1465
|
-
&x_qs[i*(
|
|
2120
|
+
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_K_q8_1_impl_mmq(
|
|
2121
|
+
&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
|
|
1466
2122
|
x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
|
1467
2123
|
}
|
|
1468
2124
|
}
|
|
1469
2125
|
}
|
|
1470
2126
|
}
|
|
1471
2127
|
|
|
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    half2 * x_dm = (half2 *) (x_qs +
+    half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
     int * x_sc = (int *) (x_dm + txs.dm);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
         }
 
         const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
-        const int ky = QR5_K*
+        const int ky = QR5_K*txi;
 
-        const int ql = get_int_b4(bxi->qs,
+        const int ql = get_int_b4(bxi->qs, txi);
         const int ql0 = (ql >> 0) & 0x0F0F0F0F;
         const int ql1 = (ql >> 4) & 0x0F0F0F0F;
 
-        const int qh = get_int_b4(bxi->qh,
-        const int qh0 = ((qh >> (2 * (
-        const int qh1 = ((qh >> (2 * (
+        const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4));
+        const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010;
+        const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010;
 
-        const int kq0 = ky - ky % (QI5_K/2) +
-        const int kq1 = ky - ky % (QI5_K/2) +
+        const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
+        const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
 #else
-        x_qs[i*(2*
-        x_qs[i*(2*
-#endif //
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
-#
-
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    constexpr int rows_per_warp = warp_size / 2;
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*
-
-
-
-
-
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+#if defined(AMD_MFMA_AVAILABLE)
+        // Need if on AMD instead of % because warp_size == 64
+        // This causes double work and throughput loss (MI300X)
+        // H100 loses about 100 t/s with 'if' condition over '%'
+        int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
+        if (i < mmq_y) {
+#else
+        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
+        {
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+            if (need_check) {
+                i = min(i, i_max);
+            }
 
-
+            const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
 
-
-
+            const int * scales = (const int *) bxi->scales;
+            const int ksc = threadIdx.x % 2;
 
-
-
+            const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
+            const int m32 = unpack_scales_q45_K(scales, ksc + 2);
 
-
-
+            const uint8_t * sc8 = (const uint8_t *) &sc32;
+            const uint8_t * m8 = (const uint8_t *) &m32;
 
-
+            const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
 
 #pragma unroll
-
-
+            for (int l = 0; l < int(sizeof(int)); ++l) {
+                x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+            }
         }
     }
-
 #else
-
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*
-        int i = (i0 + threadIdx.y*
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
+        int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
@@ -1557,9 +2226,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         x_dm[i] = bxi->dm;
     }
 
+    constexpr int rows_per_warp = warp_size / 4;
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*
-        int i = (i0 + threadIdx.y*
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
@@ -1569,17 +2239,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const int * scales = (const int *) bxi->scales;
 
-        const int ksc = threadIdx.x % (
+        const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
         const int scales8 = unpack_scales_q45_K(scales, ksc);
 
-        x_sc[i*(
+        x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
     }
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 }
 
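Aside: the `threads_per_row`/`nrows` pattern introduced throughout these loaders derives how many lanes cooperate on one block row from the iteration depth and the quantization ratio, so the same code serves 32-wide and 64-wide warps. A standalone C++ sketch; `MMQ_ITER_K = 256` and `QR5_K = 2` are assumptions about the upstream constants, not values shown in this diff:

    #include <cstdio>
    #include <initializer_list>

    int main() {
        // Assumed upstream constants: MMQ_ITER_K = 256, QR5_K = 2.
        const int MMQ_ITER_K = 256, QR5_K = 2;
        const int threads_per_row = MMQ_ITER_K / (4 * QR5_K); // lanes sharing one block row
        for (int warp_size : {32, 64}) {
            const int nrows = warp_size / threads_per_row;
            // txi is a lane's offset within its row; the ternary in the kernel skips
            // the modulo when the whole warp maps onto a single row (warp_size == 32 here).
            const int tx  = warp_size - 1;
            const int txi = warp_size > threads_per_row ? tx % threads_per_row : tx;
            printf("warp_size=%2d: threads_per_row=%d rows_per_step=%d lane%2d->txi=%d\n",
                   warp_size, threads_per_row, nrows, tx, txi);
        }
    }
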
-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
     const int * x_qs = (const int *) x;
@@ -1589,7 +2261,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
     const half2 * y_ds = (const half2 *) y;
 
     // #pragma unroll
-    for (int k01 = 0; k01 <
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -1597,36 +2269,42 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 +=
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
-                const uint8_t * sc = ((const uint8_t *) &x_sc[i * (
+                const uint8_t * sc = ((const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k00/32]) + 2*(k01/16);
 
-                sum[j0/nwarps*mmq_y/
-                    &x_qs[i*(QR5_K*
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q5_K_q8_1_impl_mmq(
+                    &x_qs[i*(QR5_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
                     x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
     }
 }
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs +
-    int * x_sc = (int *) (x_df +
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+    int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
     int * x_sc = (int *) (x_df + txs.dm);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
@@ -1634,67 +2312,67 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
 
-        const int ql = get_int_b2(bxi->ql,
+        const int ql = get_int_b2(bxi->ql, txi);
         const int ql0 = (ql >> 0) & 0x0F0F0F0F;
         const int ql1 = (ql >> 4) & 0x0F0F0F0F;
 
-        const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (
-        const int qh0 = ((qh >> ((
-        const int qh1 = (qh >> ((
+        const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4));
+        const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030;
+        const int qh1 = (qh >> ((txi & 0x08) >> 2)) & 0x30303030;
 
-        const int kq0 = 2*
-        const int kq1 = 2*
+        const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
+        const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
         x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
 #else
-        x_qs[i*(2*
-        x_qs[i*(2*
-#endif //
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
-    const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
-    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
-
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps
-        int i = (i0 + threadIdx.y
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
+        int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
         }
 
-        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride
+        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
 
-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q6_K
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d;
 #else
-        x_df[i*(
-#endif //
+        x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
+    constexpr int rows_per_warp = warp_size / 4;
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps
-        int i = (i0 + threadIdx.y
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
         }
 
-        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (
+        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;
 
-#
-        x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
 #else
-        x_sc[i*(
-#endif //
+        x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
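Aside: q6_K stores each weight as a low nibble in `ql` plus two high bits in `qh`; the loader merges them and recenters by 32 with `__vsubss4`, a byte-wise saturating subtract across a packed int. A host-side C++ stand-in for that step; the packed inputs below are invented for the demo:

    #include <cstdint>
    #include <cstdio>

    // Per-byte signed saturating subtract, mimicking CUDA's __vsubss4.
    static uint32_t vsubss4(uint32_t a, uint32_t b) {
        uint32_t r = 0;
        for (int k = 0; k < 4; ++k) {
            int va = (int8_t)(a >> 8*k), vb = (int8_t)(b >> 8*k);
            int d  = va - vb;
            d = d < -128 ? -128 : d > 127 ? 127 : d;
            r |= (uint32_t)(uint8_t)d << 8*k;
        }
        return r;
    }

    int main() {
        const uint32_t ql = 0x0F0A0501;  // four low nibbles (example data)
        const uint32_t qh = 0x03020100;  // four 2-bit high parts (example data)
        const uint32_t q  = (ql & 0x0F0F0F0F) | ((qh & 0x03030303) << 4); // 6-bit values
        const uint32_t centered = vsubss4(q, 0x20202020);                 // subtract 32
        for (int k = 0; k < 4; ++k) {
            printf("q[%d] = %d\n", k, (int8_t)(centered >> 8*k));         // range [-32, 31]
        }
    }
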
-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
     const int * x_qs = (const int *) x;
@@ -1704,7 +2382,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
     const float * y_df = (const float *) y;
 
     // #pragma unroll
-    for (int k01 = 0; k01 <
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -1712,23 +2390,126 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 +=
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
-                const int8_t * sc = ((const int8_t *) &x_sc[i * (
+                const int8_t * sc = ((const int8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/16]);
 
-                sum[j0/nwarps*mmq_y/
-                    &x_qs[i*(QR6_K*
-                    x_df[i*(
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q6_K_q8_1_impl_mmq(
+                    &x_qs[i*(QR6_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
+                    x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
    }
 }
 
-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
-#
+#if defined(AMD_MFMA_AVAILABLE)
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16, 8, int, input_layout> tile_A;
+    typedef tile<16, 8, int, input_layout> tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+    typedef tile<64, 2, int, input_layout> tile_load;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = granularity;
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+    const int * x_qs = (const int *) x;
+    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+    const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
+    const int * y_qs = (const int *) y + 4;
+    const float * y_df = (const float *) y;
+
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+        const int k0 = k00 + k01;
+
+        tile_A A[ntx];
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
+        }
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+            tile_B B[1];
+            load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+            const int j = j0 + tile_C::get_j(0);
+            const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C C;
+                mma(C, A[n], B[0]);
+
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+                    const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
+                }
+            }
+        }
+    }
+#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16, 4, int, input_layout> tile_A;
+    typedef tile<16, 4, int, input_layout> tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = granularity;
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+    const int * x_qs = (const int *) x;
+    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+    const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
+    const int * y_qs = (const int *) y + 4;
+    const float * y_df = (const float *) y;
+
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+        const int k0 = k00 + k01;
+
+        tile_A A[ntx];
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
+        }
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+            tile_B B;
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+            const int j = j0 + tile_C::get_j(0);
+            const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C C;
+                mma(C, A[n], B);
+
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+                    const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
+                }
+            }
+        }
+    }
+#elif defined(TURING_MMA_AVAILABLE)
 
     typedef tile<16, 4, int> tile_A;
     typedef tile< 8, 4, int> tile_B;
@@ -1738,11 +2519,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     constexpr int rows_per_warp = 2 * granularity;
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
 
     const int * x_qs = (const int *) x;
-    const float * x_df = (const float *) x_qs +
-    const int * x_sc = (const int *) x_df +
+    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+    const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
     const int * y_qs = (const int *) y + 4;
     const float * y_df = (const float *) y;
 
@@ -1755,7 +2536,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int k01 = 0; k01 <
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
            const int k0 = k00 + k01;
 
            load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
@@ -1763,7 +2544,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
        }
 
 #pragma unroll
-        for (int k01 = 0; k01 <
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) {
            const int k0 = k00 + k01;
 
 #pragma unroll
@@ -1793,7 +2574,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     float tmp[ntx][tile_C::ne] = {{0.0f}};
 
 #pragma unroll
-    for (int k01 = 0; k01 <
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
         tile_B B[2];
         float dB[tile_C::ne/2];
 
@@ -1830,29 +2611,34 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         }
     }
 #else
-
+    GGML_UNUSED_VARS(x, y, sum, k00);
     NO_DEVICE_CODE;
-#endif //
+#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE
 }
 
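Aside: vec_dot_q6_K_q8_1_mma now compiles one of three fragment configurations: the MFMA path reads its 16x8 operands through a 64x2 `tile_load` view, the WMMA path works on 16x4 tiles directly (per its own comment), and Turing keeps the old 16x4 / 8x4 pair. A compile-time dispatch of the same flavor, sketched in plain C++; the enum and shape values are illustrative, not the real macros:

    #include <cstdio>

    enum class Backend { AmdMfma, AmdWmma, TuringMma };

    struct Shape { int i, j; };

    // Fragment shape of the A operand per backend (illustrative values).
    constexpr Shape tile_a_shape(Backend b) {
        switch (b) {
            case Backend::AmdMfma:   return {16, 8}; // loaded through a 64x2 view
            case Backend::AmdWmma:   return {16, 4}; // native 16x4, no reinterpret
            case Backend::TuringMma: return {16, 4};
        }
        return {0, 0};
    }

    int main() {
        constexpr Shape s = tile_a_shape(Backend::AmdWmma);
        printf("tile_A = %dx%d\n", s.i, s.j);
    }
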
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs +
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-
-
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI4_NL;
+    const int kqsx = txi % QI4_NL;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
@@ -1861,23 +2647,25 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
 
         const int aux_q4 = get_int_b2(bxi->qs, kqsx);
-        const int2 v = get_int_from_table_16(aux_q4);
-        const int k0 =
-
-
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 +
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
+        const int k0 = kbx * (2 * QI4_NL) + kqsx;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
 #else
-        x_qs[i*(2*
-        x_qs[i*(2*
-#endif //
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
-
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y *
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -1885,31 +2673,35 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
 
-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
 #else
-        x_df[i*(
-#endif //
+        x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
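Aside: the iq4_nl loaders now pass the codebook explicitly (`get_int_from_table_16(aux_q4, kvalues_iq4nl)`), which expands eight packed 4-bit indices into two 4-byte groups of int8 codebook values: the low nibbles feed `v.x`, the high nibbles `v.y`. A host-side C++ sketch of that expansion; the table values follow the iq4_nl codebook as I understand it, and the packed input is made up:

    #include <cstdint>
    #include <cstdio>

    static const int8_t table[16] = {
        -127, -104, -83, -65, -49, -35, -22, -10,
           1,   13,  25,  38,  53,  69,  89, 113,
    };

    int main() {
        const uint32_t aux_q4 = 0x76543210;  // eight packed nibble indices (example data)
        int8_t lo[4], hi[4];
        for (int k = 0; k < 4; ++k) {
            lo[k] = table[(aux_q4 >> (8*k))     & 0x0F]; // low nibble of each byte -> v.x
            hi[k] = table[(aux_q4 >> (8*k + 4)) & 0x0F]; // high nibble of each byte -> v.y
        }
        for (int k = 0; k < 4; ++k) {
            printf("lo[%d]=%d hi[%d]=%d\n", k, lo[k], k, hi[k]);
        }
    }
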
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs +
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-
+    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y*
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -1932,42 +2724,46 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
             const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
 #else
-            x_qs[i*(2*
-            x_qs[i*(2*
-#endif //
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         }
 
         const int ls = aux32 >> 28;
         const float d = bxi->d;
-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
 #else
-        x_df[i*(
-#endif //
+        x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs +
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-
+    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y*
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -1986,44 +2782,47 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
             const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
 #else
-            x_qs[i*(2*
-            x_qs[i*(2*
-#endif //
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         }
 
         const int ls = bxi->scales[kqsx];
         const float d = bxi->d;
-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q3_K
-        x_df[i*MMQ_MMA_TILE_X_K_Q3_K
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
 #else
-        x_df[i*(2*
-        x_df[i*(2*
-#endif //
+        x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+        x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs +
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif //
-
-
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y*
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -2049,44 +2848,48 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
             const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
 #else
-            x_qs[i*(2*
-            x_qs[i*(2*
-#endif //
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         }
 
         const int ls = bxi->scales[kqsx];
         const float d = bxi->d;
-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q3_K
-        x_df[i*MMQ_MMA_TILE_X_K_Q3_K
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
 #else
-        x_df[i*(2*
-        x_df[i*(2*
-#endif //
+        x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+        x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }
 }
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs +
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-
+    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y*
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -2107,42 +2910,46 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
             const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
 #else
-            x_qs[i*(2*
-            x_qs[i*(2*
-#endif //
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         }
 
         const int ls = aux32 >> 28;
         const float d = bxi->d;
-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
 #else
-        x_df[i*(
-#endif //
+        x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
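Aside: all of the iq2/iq3 loaders in this range apply packed signs with the identity `(g ^ m) - m`, which is byte-wise two's-complement negation wherever the `__vcmpne4` mask byte is 0xFF and the identity wherever it is 0x00; `__vsub4` then performs the four byte subtractions in one instruction. A host-side C++ stand-in with made-up data:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t grid  = 0x04030201; // four packed byte magnitudes (example data)
        const uint32_t signs = 0x00FF00FF; // negate bytes 0 and 2, keep bytes 1 and 3
        uint32_t out = 0;
        for (int k = 0; k < 4; ++k) {
            const uint8_t g = grid  >> 8*k;
            const uint8_t m = signs >> 8*k;
            out |= (uint32_t)(uint8_t)((g ^ m) - m) << 8*k; // conditional negation
        }
        for (int k = 0; k < 4; ++k) {
            printf("byte %d: %d\n", k, (int8_t)(out >> 8*k)); // -1, 2, -3, 4
        }
    }
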
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs +
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-
+    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y*
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -2170,42 +2977,46 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
             const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
 #else
-            x_qs[i*(2*
-            x_qs[i*(2*
-#endif //
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         }
 
         const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
         const float d = bxi->d;
-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
 #else
-        x_df[i*(
-#endif //
+        x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    half2 * x_ds = (half2 *) (x_qs +
+    half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_ds = (half2 *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y*
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -2225,66 +3036,71 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             const int grid0 = (grid >> 0) & 0x0F0F0F0F;
             const int grid1 = (grid >> 4) & 0x0F0F0F0F;
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
 #else
-            x_qs[i*(2*
-            x_qs[i*(2*
-#endif //
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         }
 
         const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
         const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
 
-#
-        x_ds[i*MMQ_MMA_TILE_X_K_Q8_1
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
 #else
-        x_ds[i*(
-#endif //
+        x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs +
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-
-
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
         }
 
-        const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride
+        const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
 
         const int aux_q4 = get_int_b4(bxi->qs, kqsx);
-        const int2 v = get_int_from_table_16(aux_q4);
-        const int k0 = 8 * (
-
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
+        const int k0 = 8 * (kqsx / 4) + kqsx % 4;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
 #else
-        x_qs[i*(2*
-        x_qs[i*(2*
-#endif //
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
+    constexpr int rows_per_warp = warp_size / 8;
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y *
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4);
 
         if (need_check) {
             i = min(i, i_max);
@@ -2297,18 +3113,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
             | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
 
-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
 #else
-        x_df[i*(
-#endif //
+        x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
-template<int mmq_x, int mmq_y,
+template<int mmq_x, int mmq_y, bool need_check>
 static __device__ __forceinline__ void mmq_write_back_dp4a(
     const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst,
     const int stride, const int i_max, const int j_max) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
@@ -2318,32 +3137,42 @@ static __device__ __forceinline__ void mmq_write_back_dp4a(
         }
 
 #pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 +=
+        for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
            const int i = i0 + threadIdx.x;
 
            if (need_check && i > i_max) {
                continue;
            }
 
-            dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/
+            dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
        }
    }
 }
 
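Aside: the dp4a write-back walks the same register layout as the dot-product loops and then scatters whole rows through `ids_dst`. A single-thread host C++ sketch of the addressing; the shapes and the identity `ids_dst` mapping are assumptions for the demo:

    #include <cstdio>

    int main() {
        const int mmq_x = 16, mmq_y = 64, nwarps = 8, warp_size = 32; // assumed shapes
        const int stride = 64;
        float sum[(16/8)*(64/32)] = {1.0f, 2.0f, 3.0f, 4.0f};
        int   ids_dst[16];
        for (int j = 0; j < mmq_x; ++j) ids_dst[j] = j;  // identity row remap here
        float dst[16*64] = {0};

        const int ty = 0, tx = 0; // one thread shown; the kernel runs all of them
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + ty;                       // this thread's output row
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int i = i0 + tx;                   // this thread's output column
                dst[ids_dst[j]*stride + i] =
                    sum[(j0/nwarps)*(mmq_y/warp_size) + i0/warp_size];
            }
        }
        printf("dst[0]=%g dst[32]=%g\n", dst[0], dst[32]); // 1 and 2
    }
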
2333
|
-
template<
|
|
3152
|
+
template<ggml_type type, int mmq_x, int mmq_y, bool need_check>
|
|
2334
3153
|
static __device__ __forceinline__ void mmq_write_back_mma(
|
|
2335
3154
|
const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst,
|
|
2336
3155
|
const int stride, const int i_max, const int j_max) {
|
|
2337
|
-
typedef tile<16, 8, int> tile_C;
|
|
2338
3156
|
|
|
2339
3157
|
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
3158
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
|
3159
|
+
|
|
3160
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
3161
|
+
constexpr int tileC_IJ = mmq_get_granularity_device(0);
|
|
3162
|
+
typedef tile<tileC_IJ, tileC_IJ, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
|
3163
|
+
constexpr int rows_per_warp = granularity;
|
|
3164
|
+
#else
|
|
3165
|
+
typedef tile<16, 8, int> tile_C;
|
|
2340
3166
|
constexpr int rows_per_warp = 2 * granularity;
|
|
3167
|
+
#endif // defined(AMD_MFMA_AVAILABLE)
|
|
2341
3168
|
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
|
2342
3169
|
|
|
2343
3170
|
const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
|
|
2344
|
-
#
|
|
3171
|
+
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2345
3172
|
static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
|
|
2346
|
-
#
|
|
3173
|
+
#else
|
|
3174
|
+
GGML_UNUSED(nwarps);
|
|
3175
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2347
3176
|
|
|
2348
3177
|
#pragma unroll
|
|
2349
3178
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
@@ -2371,188 +3200,212 @@ static __device__ __forceinline__ void mmq_write_back_mma(

 // -------------------------------------------------------------------------------------------------------------------------------------

-template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
+template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
 struct mmq_type_traits;

-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
     static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_DS4>;
-    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_DS4>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y>;
 };

-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_1> {
     static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y>;
 };

-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_0> {
     static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
-    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
 };

-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_1> {
     static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
 };

-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
     static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
-    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
+    static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
+#ifdef BLACKWELL_MMA_AVAILABLE
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
+#else
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+#endif // BLACKWELL_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
 };

-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
     static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y>;
 };

-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q3_K> {
     static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y>;
 };

-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_K> {
     static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y>;
 };

-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_K> {
     static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y>;
 };

-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q6_K> {
     static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y>;
 };

-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XXS> {
     static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
-    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
 };

-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XS> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XS> {
     static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
 };

-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_S> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_S> {
     static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
 };

-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_XXS> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_XXS> {
     static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
-    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
 };

-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_S> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_S> {
     static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
-    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
 };

-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_S> {
     static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
 };

-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_NL> {
     static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
-    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
 };

-template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_XS> {
     static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
-    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
 };

-template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
+template <ggml_type type, int mmq_x, bool need_check, bool fixup>
 static __device__ __forceinline__ void mul_mat_q_process_tile(
     const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
     const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
     const int stride_row_x, const int ncols_y, const int stride_col_dst,
     const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {

+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int qk = ggml_cuda_type_traits<type>::qk;
     constexpr int mmq_y = get_mmq_y_device();
-    constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
+    constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, need_check, type>::load_tiles;

     extern __shared__ int data_mul_mat_q[];
     int * tile_y = data_mul_mat_q + mmq_x;
-    int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*WARP_SIZE);
+    int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
+    constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
+#else
+    constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
+    constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

-#if …
-    constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
-    constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#if defined(BLACKWELL_MMA_AVAILABLE)
+    // FP4 tile stores 8 blocks
+    constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
 #else
-    constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
-    constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
-#endif // …
+    constexpr int ne_block = 4 * QK8_1;
+#endif // defined(BLACKWELL_MMA_AVAILABLE)
+
+    constexpr int ITER_K = get_iter_k(type);
+    constexpr int blocks_per_iter = ITER_K / qk;

-    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+    float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};

-
+    constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int);

     for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
         load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
-
         {
-            const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq)/(4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
+            const int * by0 = y + ncols_y * (kb0 * qk / ne_block) * sz;
 #pragma unroll
-            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
-                int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+            for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
+                int l = l0 + threadIdx.y*warp_size + threadIdx.x;

                 tile_y[l] = by0[l];
             }
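The `mmq_type_traits` specializations above form a compile-time dispatch table: the quantization type selects a tile loader and two dot-product kernels, all resolved at template instantiation with no runtime branching. A minimal sketch of the same pattern, with hypothetical names (`example_type`, `kernel_a`, and friends are stand-ins, not part of ggml):

```cpp
#include <cstdio>

enum example_type { TYPE_A, TYPE_B }; // stand-ins for ggml_type values

using kernel_fn = void (*)();

static void kernel_a() { std::puts("kernel for TYPE_A"); }
static void kernel_b() { std::puts("kernel for TYPE_B"); }

// Primary template left undefined: unsupported types fail at compile time.
template <int tile_x, example_type type>
struct type_traits;

template <int tile_x>
struct type_traits<tile_x, TYPE_A> { static constexpr kernel_fn fn = kernel_a; };

template <int tile_x>
struct type_traits<tile_x, TYPE_B> { static constexpr kernel_fn fn = kernel_b; };

template <int tile_x, example_type type>
void dispatch() {
    constexpr kernel_fn fn = type_traits<tile_x, type>::fn; // resolved at compile time
    fn();
}

int main() {
    dispatch<64, TYPE_A>();
    dispatch<64, TYPE_B>();
    return 0;
}
```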
@@ -2565,10 +3418,10 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
         __syncthreads();

         {
-            const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq)/(4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
+            const int * by0 = y + ncols_y * ((kb0 * qk / ne_block) * sz + sz);
 #pragma unroll
-            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
-                int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+            for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
+                int l = l0 + threadIdx.y*warp_size + threadIdx.x;

                 tile_y[l] = by0[l];
             }
@@ -2576,7 +3429,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(

         __syncthreads();

-        vec_dot(tile_x, tile_y, sum, WARP_SIZE);
+        vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K);

         __syncthreads();
     }
@@ -2591,24 +3444,25 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(

 // The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598

-template <ggml_type type, int mmq_x, int nwarps, bool need_check>
-#if defined(GGML_USE_HIP)
+template <ggml_type type, int mmq_x, bool need_check>
+#if defined(GGML_USE_HIP)
 #if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
-    __launch_bounds__(WARP_SIZE*nwarps, 2)
+    __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
 #endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
 #else
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
-    __launch_bounds__(WARP_SIZE*nwarps, 1)
+    __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1)
 #else
-    __launch_bounds__(WARP_SIZE*nwarps, 2)
+    __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
-#endif // defined(GGML_USE_HIP)
+#endif // defined(GGML_USE_HIP)
 static __global__ void mul_mat_q(
     const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
     const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
     const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
     const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
-    const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+    const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+    const int ncols_max) {

     // Skip unused template specializations for faster compilation:
     if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
@@ -2616,10 +3470,13 @@ static __global__ void mul_mat_q(
         return;
     }

+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
     constexpr int qk = ggml_cuda_type_traits<type>::qk;
     constexpr int mmq_y = get_mmq_y_device();

-    const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; // Number of tiles x
+    const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
     const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y

     // Initialize the ids for writing back data with just the index.
@@ -2627,10 +3484,10 @@ static __global__ void mul_mat_q(
     // For MoE the correct indices are loaded from ids_dst.
     extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory.
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
-        const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+        const int j = j0 + threadIdx.y*warp_size + threadIdx.x;

-        if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
+        if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
             break;
         }
@@ -2638,8 +3495,8 @@ static __global__ void mul_mat_q(
     }
     __syncthreads();

-    // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
-#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
+    // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
+#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
     {
         const int wt = blockIdx.z / nchannels_y;
         const int zt = blockIdx.z - wt*nchannels_y;
@@ -2667,10 +3524,10 @@ static __global__ void mul_mat_q(

         // __syncthreads(); // There is no previous tile that could cause a race condition.
 #pragma unroll
-        for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
-            const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+            const int j = j0 + threadIdx.y*warp_size + threadIdx.x;

-            if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
+            if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
                 break;
             }
@@ -2688,15 +3545,17 @@ static __global__ void mul_mat_q(
         const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;

         constexpr bool fixup = false;
-        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+        mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
             (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
              tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
         return;
     }
-#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
+#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
+
+    constexpr int ITER_K = get_iter_k(type);

     const int64_t blocks_per_ne00 = ncols_x / qk;
-    constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+    constexpr int blocks_per_iter = ITER_K / qk;

     // kbc == k block continuous, current index in continuous ijk space.
     int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
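Per the cited paper, stream-k gives each thread block a contiguous slice of the flattened sample × channel × tile × k-block index space rather than one output tile per block, which keeps all SMs busy for awkward problem shapes. A host-side sketch of that split, with illustrative sizes (the kernel additionally rounds the boundaries to multiples of `blocks_per_iter`, which is omitted here):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Flattened work size nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 and
    // grid size are made-up examples, not values taken from the kernel.
    const int64_t total   = 1000;
    const int     nblocks = 12; // gridDim.x, typically tied to the SM count

    for (int bidx = 0; bidx < nblocks; ++bidx) {
        const int64_t kbc      =  (int64_t) bidx      * total / nblocks;
        const int64_t kbc_stop = ((int64_t) bidx + 1) * total / nblocks;
        // A block whose slice ends mid-tile writes its partial sums to the
        // fixup buffer; mul_mat_q_stream_k_fixup later adds them into dst.
        std::printf("block %2d: [%4lld, %4lld)\n", bidx, (long long) kbc, (long long) kbc_stop);
    }
    return 0;
}
```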
@@ -2745,10 +3604,10 @@ static __global__ void mul_mat_q(

             __syncthreads();
 #pragma unroll
-            for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
-                const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+            for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+                const int j = j0 + threadIdx.y*warp_size + threadIdx.x;

-                if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
+                if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
                     break;
                 }
@@ -2757,7 +3616,7 @@ static __global__ void mul_mat_q(
             __syncthreads();
         }

-        offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
+        offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
         offset_dst += it*mmq_y;

         const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@@ -2766,7 +3625,7 @@ static __global__ void mul_mat_q(
         const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;

         constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
-        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+        mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
             (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
              tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
@@ -2812,10 +3671,10 @@ static __global__ void mul_mat_q(
         // The memory layout for the fixup buffer is always contiguous, therefore reset ids:
         __syncthreads();
 #pragma unroll
-        for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
-            const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+            const int j = j0 + threadIdx.y*warp_size + threadIdx.x;

-            if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
+            if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
                 break;
             }
@@ -2824,7 +3683,7 @@ static __global__ void mul_mat_q(
         __syncthreads();
     }

-    offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
+    offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
     offset_dst += it*mmq_y;

     const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@@ -2833,25 +3692,31 @@ static __global__ void mul_mat_q(
     const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;

     constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
-    mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+    mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
         (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
          tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
 }


-template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+template <ggml_type type, int mmq_x, bool need_check>
 static __global__ void mul_mat_q_stream_k_fixup(
     const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
     const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
-    const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
+    const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
+    const int ncols_max) {
     constexpr int mmq_y = get_mmq_y_device();
     constexpr int qk = ggml_cuda_type_traits<type>::qk;
-    constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+    constexpr int ITER_K = get_iter_k(type);
+
+    constexpr int blocks_per_iter = ITER_K / qk;
     const int64_t blocks_per_ne00 = ncols_x / qk;

-    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+    float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};

-    const int ntx = (ncols_dst + mmq_x - 1) / mmq_x;
+    const int ntx = (ncols_max + mmq_x - 1) / mmq_x;
     const int nty = (nrows_x + mmq_y - 1) / mmq_y;

     const int bidx0 = blockIdx.x;
@@ -2893,10 +3758,10 @@ static __global__ void mul_mat_q_stream_k_fixup(
         const int j = j0 + threadIdx.y;

 #pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
             const int i = i0 + threadIdx.x;

-            sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
+            sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
         }
     }

@@ -2937,14 +3802,14 @@ static __global__ void mul_mat_q_stream_k_fixup(
         }

 #pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
             const int i = i0 + threadIdx.x;

             if (need_check && i > i_max) {
                 continue;
             }

-            dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+            dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
         }
     }
     return;
@@ -2955,8 +3820,8 @@ static __global__ void mul_mat_q_stream_k_fixup(
     const int col_high = expert_bounds[zt + 1];
     const int col_diff = col_high - col_low;

-    for (int j = threadIdx.y*WARP_SIZE + threadIdx.x; j < mmq_x; j += nwarps*WARP_SIZE) {
-        ids_dst_shared[j] = ids_dst[col_low + j];
+    for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {
+        ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
     }
     __syncthreads();

@@ -2975,14 +3840,14 @@ static __global__ void mul_mat_q_stream_k_fixup(
         }

 #pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
             const int i = i0 + threadIdx.x;

             if (need_check && i > i_max) {
                 continue;
             }

-            dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+            dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
         }
     }
 }
@@ -2992,17 +3857,17 @@ struct mmq_args {
     int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
     int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
     int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
-    bool use_stream_k;
+    bool use_stream_k; int64_t ncols_max;
 };

 template<ggml_type type>
-static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc) {
+static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) {
     const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
     const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
     const size_t nbs_ids = mmq_x*sizeof(int);
-    const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
-    const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
-    return nbs_ids + nbs_x + GGML_PAD(nbs_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
+    const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+    const size_t nbs_y = mmq_x * (sizeof(block_q8_1_mmq));
+    return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
 }

 template <ggml_type type, int mmq_x>
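`mmq_get_nbytes_shared` above sums three shared-memory regions: the destination ids, the x tile, and the y tile padded so that the warp-strided copy loop in `mul_mat_q_process_tile` cannot run past the end. A sketch of the padding arithmetic with made-up byte counts (`pad` mirrors what `GGML_PAD` does; none of the sizes below are the real tile sizes):

```cpp
#include <cstddef>
#include <cstdio>

// Round x up to a multiple of n, like GGML_PAD.
constexpr size_t pad(size_t x, size_t n) { return (x + n - 1) / n * n; }

int main() {
    const size_t nbs_ids = 64 * sizeof(int);        // mmq_x ids (example)
    const size_t nbs_x   = 128 * 264 * sizeof(int); // mmq_y * mmq_tile_x_k (example)
    const size_t nbs_y   = 64 * 144;                // mmq_x * sizeof(block_q8_1_mmq) (example)
    // Pad the y tile to a whole number of nwarps*warp_size int copies:
    const size_t total = nbs_ids + nbs_x + pad(nbs_y, 8 * 32 * sizeof(int));
    std::printf("shared memory needed: %zu bytes\n", total);
    return 0;
}
```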
@@ -3010,23 +3875,19 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     const int id = ggml_cuda_get_device();
     const int cc = ggml_cuda_info().devices[id].cc;
     const int nsm = ggml_cuda_info().devices[id].nsm;
+    const int warp_size = ggml_cuda_info().devices[id].warp_size;
+    const int nwarps = mmq_get_nwarps_host(cc, warp_size);
     const int mmq_y = get_mmq_y_host(cc);

-    const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
+    const dim3 block_dims(warp_size, nwarps, 1);

-    const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);
+    const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps);

-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
-    static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
-    if (!shared_memory_limit_raised[id]) {
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
-        shared_memory_limit_raised[id] = true;
-    }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, false>), nbytes_shared);
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true>), nbytes_shared);

     const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
-    const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
+    const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x;
     const int ntzw = args.nchannels_y * args.nsamples_y;
     const dim3 block_nums_xy_tiling(nty, ntx, ntzw);

@@ -3038,18 +3899,20 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     if (!args.use_stream_k) {
         if (args.nrows_x % mmq_y == 0) {
             constexpr bool need_check = false;
-            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
+            mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
                 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
                  args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
                  channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+                 args.ncols_max);
         } else {
             constexpr bool need_check = true;
-            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
+            mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
                 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
                  args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
                  channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+                 args.ncols_max);
         }
         return;
     }
@@ -3065,44 +3928,48 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {

     if (args.nrows_x % mmq_y == 0) {
         constexpr bool need_check = false;
-
-        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+        mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
            (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-            sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+            sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+            args.ncols_max);

         if (!fixup_needed) {
             return;
         }

-        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
+        mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
-            args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
+            args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+            args.ncols_max);
     } else {
         constexpr bool need_check = true;
-
-        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+        mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
            (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-            sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+            sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+            args.ncols_max);

         if (!fixup_needed) {
             return;
         }

-        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
+        mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
-            args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
+            args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+            args.ncols_max);
     }
 }

 template <ggml_type type>
 void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
-    const int id = ggml_cuda_get_device();
-    const int cc = ggml_cuda_info().devices[id].cc;
-    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+    const int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+    const int warp_size = ggml_cuda_info().devices[id].warp_size;
+    const int nwarps = mmq_get_nwarps_host(cc, warp_size);

     const int mmq_x_max = get_mmq_x_max_host(cc);
     const int mmq_y = get_mmq_y_host(cc);
@@ -3113,11 +3980,11 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) {
         const int granularity = mmq_get_granularity_host(mmq_x, cc);

-        if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc) > smpbo) {
+        if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) {
             continue;
         }

-        const int ntiles_x = (args.ncols_dst + mmq_x - 1) / mmq_x;
+        const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x;

         if (ntiles_x < ntiles_x_best) {
             mmq_x_best = mmq_x;
@@ -3189,6 +4056,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
 extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
+extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
 extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
@@ -3214,4 +4082,4 @@ void ggml_cuda_op_mul_mat_q(
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream);

-bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11);
+bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);