whispercpp 1.3.2 → 1.3.4
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- checksums.yaml +4 -4
- data/.gitignore +6 -3
- data/README.md +71 -14
- data/Rakefile +20 -7
- data/ext/.gitignore +4 -6
- data/ext/dependencies.rb +36 -24
- data/ext/extconf.rb +1 -1
- data/ext/options.rb +48 -184
- data/ext/ruby_whisper.c +18 -0
- data/ext/ruby_whisper_context.c +43 -12
- data/ext/ruby_whisper_model.c +1 -1
- data/ext/ruby_whisper_params.c +59 -27
- data/ext/ruby_whisper_segment.c +81 -4
- data/ext/ruby_whisper_transcribe.cpp +13 -7
- data/ext/ruby_whisper_vad_params.c +1 -1
- data/ext/sources/CMakeLists.txt +5 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/build-xcframework.sh +24 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
- data/ext/sources/examples/addon.node/addon.cpp +154 -35
- data/ext/sources/examples/addon.node/index.js +10 -5
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/bench.cpp +29 -18
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +7 -4
- data/ext/sources/examples/command/command.cpp +58 -32
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/common-whisper.cpp +14 -7
- data/ext/sources/examples/lsp/lsp.cpp +21 -17
- data/ext/sources/examples/quantize/quantize.cpp +3 -0
- data/ext/sources/examples/server/CMakeLists.txt +3 -0
- data/ext/sources/examples/server/server.cpp +193 -35
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +10 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -0
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
- data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +756 -15
- data/ext/sources/examples/talk-llama/llama-arch.h +85 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +773 -272
- data/ext/sources/examples/talk-llama/llama-batch.h +126 -55
- data/ext/sources/examples/talk-llama/llama-chat.cpp +150 -13
- data/ext/sources/examples/talk-llama/llama-chat.h +8 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +814 -542
- data/ext/sources/examples/talk-llama/llama-context.h +68 -32
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
- data/ext/sources/examples/talk-llama/llama-graph.cpp +787 -440
- data/ext/sources/examples/talk-llama/llama-graph.h +333 -153
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +128 -6
- data/ext/sources/examples/talk-llama/llama-hparams.h +80 -17
- data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +137 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +1248 -1967
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +218 -345
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +164 -52
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +266 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +139 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1154 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +182 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +94 -4
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +44 -17
- data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +11377 -5248
- data/ext/sources/examples/talk-llama/llama-model.h +87 -9
- data/ext/sources/examples/talk-llama/llama-quant.cpp +137 -16
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +502 -38
- data/ext/sources/examples/talk-llama/llama-vocab.h +46 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -17
- data/ext/sources/examples/talk-llama/llama.h +176 -151
- data/ext/sources/examples/talk-llama/talk-llama.cpp +11 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +212 -0
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +6 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
- data/ext/sources/ggml/CMakeLists.txt +106 -33
- data/ext/sources/ggml/cmake/common.cmake +24 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-backend.h +18 -2
- data/ext/sources/ggml/include/ggml-cpu.h +2 -0
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml.h +365 -21
- data/ext/sources/ggml/src/CMakeLists.txt +98 -25
- data/ext/sources/ggml/src/ggml-alloc.c +265 -141
- data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +35 -13
- data/ext/sources/ggml/src/ggml-backend.cpp +266 -60
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +4 -4
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +15 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
- data/ext/sources/ggml/src/ggml-cann/common.h +149 -2
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +521 -78
- data/ext/sources/ggml/src/ggml-common.h +21 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +165 -50
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -3
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +3650 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1891 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2160 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1897 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +214 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +18 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +23 -7
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +179 -110
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +44 -33
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +228 -98
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +532 -1124
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +3374 -2081
- data/ext/sources/ggml/src/ggml-cpu/ops.h +13 -8
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1193 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +34 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1982 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +120 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +367 -46
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +3 -3
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +272 -35
- data/ext/sources/ggml/src/ggml-cpu/vec.h +794 -142
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +291 -81
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +117 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +20 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +499 -368
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +142 -93
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +90 -50
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +636 -222
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +73 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +506 -0
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +4 -5
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -98
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +23 -19
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +179 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -2
- data/ext/sources/ggml/src/ggml-impl.h +229 -175
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +21 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +163 -63
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3208 -1575
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +32 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4430 -792
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +138 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +117 -24
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +85 -62
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +13 -17
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +21 -2
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +116 -211
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +700 -1041
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +20 -9
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +17 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +2 -96
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +393 -250
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +32 -8
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -11
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +125 -21
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +4 -3
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +105 -17
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4198 -1145
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +64 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +338 -71
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml.c +802 -142
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +32 -4
- data/ext/sources/include/whisper.h +2 -0
- data/ext/sources/src/CMakeLists.txt +2 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
- data/ext/sources/src/whisper.cpp +241 -215
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +3 -3
- data/ext/sources/tests/test-vad.cpp +2 -2
- data/extsources.rb +15 -9
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +57 -2
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +75 -38
- data/{tests → test}/helper.rb +1 -12
- data/{tests → test}/test_model.rb +9 -0
- data/test/test_package.rb +51 -0
- data/{tests → test}/test_params.rb +8 -0
- data/test/test_segment.rb +146 -0
- data/{tests → test}/test_whisper.rb +70 -0
- data/whispercpp.gemspec +2 -3
- metadata +246 -191
- data/ext/sources/.dockerignore +0 -3
- data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
- data/ext/sources/ci/run.sh +0 -336
- data/ext/sources/close-issue.yml +0 -28
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -6431
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -336
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -5998
- data/tests/test_package.rb +0 -46
- data/tests/test_segment.rb +0 -74
- /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
- /data/{tests → test}/test_callback.rb +0 -0
- /data/{tests → test}/test_error.rb +0 -0
- /data/{tests → test}/test_vad.rb +0 -0
- /data/{tests → test}/test_vad_params.rb +0 -0
@@ -58,6 +58,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
|
|
58
58
|
return MMQ_Q8_1_DS_LAYOUT_DS4;
|
59
59
|
case GGML_TYPE_Q8_0:
|
60
60
|
return MMQ_Q8_1_DS_LAYOUT_D4;
|
61
|
+
case GGML_TYPE_MXFP4:
|
62
|
+
return MMQ_Q8_1_DS_LAYOUT_D4;
|
61
63
|
case GGML_TYPE_Q2_K:
|
62
64
|
return MMQ_Q8_1_DS_LAYOUT_D2S6;
|
63
65
|
case GGML_TYPE_Q3_K:
|
@@ -90,7 +92,7 @@ struct tile_x_sizes {
|
|
90
92
|
};
|
91
93
|
|
92
94
|
static int get_mmq_x_max_host(const int cc) {
|
93
|
-
return
|
95
|
+
return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 :
|
94
96
|
GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
|
95
97
|
#ifdef GGML_CUDA_FORCE_MMQ
|
96
98
|
128 : 64;
|
@@ -100,13 +102,13 @@ static int get_mmq_x_max_host(const int cc) {
|
|
100
102
|
}
|
101
103
|
|
102
104
|
static constexpr __device__ int get_mmq_x_max_device() {
|
103
|
-
#
|
105
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
104
106
|
return 128;
|
105
|
-
#else //
|
107
|
+
#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
106
108
|
|
107
|
-
#if defined(GGML_USE_HIP)
|
108
|
-
return
|
109
|
-
#else // defined(GGML_USE_HIP)
|
109
|
+
#if defined(GGML_USE_HIP)
|
110
|
+
return 64;
|
111
|
+
#else // defined(GGML_USE_HIP)
|
110
112
|
|
111
113
|
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
112
114
|
#ifdef GGML_CUDA_FORCE_MMQ
|
@@ -115,12 +117,11 @@ static constexpr __device__ int get_mmq_x_max_device() {
|
|
115
117
|
return MMQ_DP4A_MAX_BATCH_SIZE;
|
116
118
|
#endif // GGML_CUDA_FORCE_MMQ
|
117
119
|
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
118
|
-
|
119
120
|
return 64;
|
120
121
|
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
121
122
|
|
122
|
-
#endif // defined(GGML_USE_HIP)
|
123
|
-
#endif //
|
123
|
+
#endif // defined(GGML_USE_HIP)
|
124
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
124
125
|
}
|
125
126
|
|
126
127
|
static int get_mmq_y_host(const int cc) {
|
@@ -129,7 +130,7 @@ static int get_mmq_y_host(const int cc) {
|
|
129
130
|
}
|
130
131
|
|
131
132
|
static constexpr __device__ int get_mmq_y_device() {
|
132
|
-
#if defined(GGML_USE_HIP)
|
133
|
+
#if defined(GGML_USE_HIP)
|
133
134
|
#if defined(RDNA1)
|
134
135
|
return 64;
|
135
136
|
#else
|
@@ -141,19 +142,28 @@ static constexpr __device__ int get_mmq_y_device() {
|
|
141
142
|
#else
|
142
143
|
return 64;
|
143
144
|
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
144
|
-
#endif // defined(GGML_USE_HIP)
|
145
|
+
#endif // defined(GGML_USE_HIP)
|
145
146
|
}
|
146
147
|
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
#define
|
155
|
-
|
156
|
-
#define
|
148
|
+
// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes.
|
149
|
+
// The K dimension of the tiles has either,
|
150
|
+
// 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K) or 2*MMQ_TILE_NE_K==64 (typically for TILE_X_K),
|
151
|
+
// 32 bit elements for the quantized data (does not include scales).
|
152
|
+
// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K.
|
153
|
+
// The final tile size in K direction is padded to avoid shared memory bank conflicts,
|
154
|
+
// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma.
|
155
|
+
#define MMQ_TILE_NE_K 32
|
156
|
+
|
157
|
+
#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0 + mmq_y/QI4_0, 0}
|
158
|
+
#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1 + mmq_y/QI4_1, 0}
|
159
|
+
#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0}
|
160
|
+
#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0}
|
161
|
+
#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0}
|
162
|
+
#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K + mmq_y, 0}
|
163
|
+
#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
|
164
|
+
#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
|
165
|
+
#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K + mmq_y/QI5_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
|
166
|
+
#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K + mmq_y/QI6_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
|
157
167
|
|
158
168
|
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
|
159
169
|
switch (type) {
|
@@ -162,6 +172,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
|
|
162
172
|
case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
|
163
173
|
case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
|
164
174
|
case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
|
175
|
+
case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
|
165
176
|
case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
|
166
177
|
case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
|
167
178
|
case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
|
@@ -179,11 +190,11 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
|
|
179
190
|
}
|
180
191
|
}
|
181
192
|
|
182
|
-
#define MMQ_MMA_TILE_X_K_Q8_0 (2*
|
183
|
-
#define MMQ_MMA_TILE_X_K_Q8_1 (2*
|
184
|
-
#define MMQ_MMA_TILE_X_K_Q2_K (2*
|
185
|
-
#define MMQ_MMA_TILE_X_K_Q3_K (2*
|
186
|
-
#define MMQ_MMA_TILE_X_K_Q6_K (2*
|
193
|
+
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
|
194
|
+
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
|
195
|
+
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
|
196
|
+
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
|
197
|
+
#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
|
187
198
|
|
188
199
|
static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
|
189
200
|
static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
|
@@ -198,6 +209,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
|
198
209
|
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
199
210
|
case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
|
200
211
|
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
212
|
+
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
|
201
213
|
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
|
202
214
|
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
|
203
215
|
case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
|
@@ -215,42 +227,76 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
|
215
227
|
}
|
216
228
|
}
|
217
229
|
|
218
|
-
|
230
|
+
// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
|
231
|
+
#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
|
219
232
|
|
220
233
|
static int mmq_get_granularity_host(const int mmq_x, const int cc) {
|
221
|
-
|
234
|
+
if (amd_mfma_available(cc)) {
|
235
|
+
return mmq_x >= 128 ? 32 : 16;
|
236
|
+
} else if (turing_mma_available(cc) && mmq_x >= 48) {
|
237
|
+
return 16;
|
238
|
+
} else {
|
239
|
+
return 8;
|
240
|
+
}
|
222
241
|
}
|
223
242
|
|
224
|
-
#
|
243
|
+
#if defined(AMD_MFMA_AVAILABLE)
|
244
|
+
static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
|
245
|
+
return mmq_x >= 128 ? 32 : 16;
|
246
|
+
}
|
247
|
+
#elif defined(TURING_MMA_AVAILABLE)
|
225
248
|
static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
|
226
249
|
return mmq_x >= 48 ? 16 : 8;
|
227
250
|
}
|
228
251
|
#else
|
229
|
-
static constexpr __device__ int mmq_get_granularity_device(const int /*
|
252
|
+
static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) {
|
230
253
|
return 8;
|
231
254
|
}
|
232
|
-
#endif //
|
255
|
+
#endif // AMD_MFMA_AVAILABLE
|
256
|
+
|
257
|
+
#if defined(GGML_USE_HIP)
|
258
|
+
static int mmq_get_nwarps_host(const int cc, const int warp_size) {
|
259
|
+
return amd_mfma_available(cc) ? 8 : 256/warp_size;
|
260
|
+
}
|
261
|
+
#else
|
262
|
+
static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
|
263
|
+
return 256/warp_size;
|
264
|
+
}
|
265
|
+
#endif // (GGML_USE_HIP)
|
266
|
+
|
267
|
+
static constexpr __device__ int mmq_get_nwarps_device() {
|
268
|
+
#if defined(AMD_MFMA_AVAILABLE)
|
269
|
+
return 8;
|
270
|
+
#else
|
271
|
+
return 256/ggml_cuda_get_physical_warp_size();
|
272
|
+
#endif // AMD_MFMA_AVAILABLE
|
273
|
+
}
|
 
 // ------------------------------------------------------------
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs + 2*
+    float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
-
-
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI4_0;
+    const int kqsx = txi % QI4_0;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
@@ -259,20 +305,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b2(bxi->qs, kqsx);
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0]     = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
 #else
-        x_qs[i*(
-#endif //
+        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
-
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y *
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -280,17 +327,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
 
-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
 #else
-        x_df[i*(
-#endif //
+        x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     const int * x_qs = (const int *) x;
@@ -299,7 +348,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
     const half2 * y_ds = (const half2 *) y;
 
 // #pragma unroll
-    for (int k01 = 0; k01 <
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -307,7 +356,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 +=
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
                 const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
@@ -320,32 +369,37 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
                     u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
                 }
 
-                sum[j0/nwarps*mmq_y/
-                    (&x_qs[i*(
-                     x_df[i*(
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+                    (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u,
+                     x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
            }
        }
    }
 }
 
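
The reworked load_tiles/vec_dot loops above replace the fixed one-row-per-warp step with a threads_per_row/nrows split, so a 64-wide wavefront can cover two tile rows per iteration while a 32-wide warp keeps the previous one-row mapping. A small sketch of that lane-to-(row, block, element) mapping, using assumed values MMQ_ITER_K = 256, QR4_0 = 2 and QI4_0 = 4 (typical ggml constants, not stated in this diff):

// Sketch: how one lane of a warp/wavefront is mapped to a q4_0 tile position
// by the new threads_per_row / nrows / txi indexing.
#include <cstdio>

constexpr int MMQ_ITER_K = 256; // assumed
constexpr int QR4_0      = 2;   // assumed
constexpr int QI4_0      = 4;   // assumed

static void describe_lane(int lane, int warp_size) {
    const int threads_per_row = MMQ_ITER_K / (4 * QR4_0);     // 32 lanes per tile row
    const int nrows           = warp_size / threads_per_row;  // 1 (32-wide) or 2 (64-wide)
    const int txi  = warp_size > threads_per_row ? lane % threads_per_row : lane;
    const int row  = nrows == 1 ? 0 : lane / threads_per_row; // extra row covered by wide wavefronts
    const int kbx  = txi / QI4_0;                             // which q4_0 block along k
    const int kqsx = txi % QI4_0;                             // which 32-bit int inside that block
    printf("warp_size %2d lane %2d -> row +%d, block %d, int %d\n", warp_size, lane, row, kbx, kqsx);
}

int main() {
    describe_lane(21, 32); // 32-wide warp: every lane stays on the same tile row
    describe_lane(37, 64); // 64-wide wavefront: lanes 32..63 load the next tile row
    return 0;
}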
331
|
-
template <int mmq_y,
|
380
|
+
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
|
332
381
|
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
382
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
383
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
333
384
|
|
334
|
-
#
|
385
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
335
386
|
int * x_qs = (int *) x_tile;
|
336
|
-
half2 * x_dm = (half2 *) (x_qs + 2*
|
387
|
+
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
|
337
388
|
#else
|
338
389
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
|
339
390
|
int * x_qs = (int *) x_tile;
|
340
391
|
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
341
|
-
#endif //
|
392
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
342
393
|
|
343
|
-
|
344
|
-
|
394
|
+
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
|
395
|
+
constexpr int nrows = warp_size / threads_per_row;
|
396
|
+
const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
|
397
|
+
const int kbx = txi / QI4_1;
|
398
|
+
const int kqsx = txi % QI4_1;
|
345
399
|
|
346
400
|
#pragma unroll
|
347
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
348
|
-
int i = i0 + threadIdx.y;
|
401
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
|
402
|
+
int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
|
349
403
|
|
350
404
|
if (need_check) {
|
351
405
|
i = min(i, i_max);
|
@@ -354,20 +408,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
354
408
|
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
|
355
409
|
const int qs0 = get_int_b4(bxi->qs, kqsx);
|
356
410
|
|
357
|
-
#
|
411
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
358
412
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
|
359
413
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
|
360
414
|
#else
|
361
|
-
x_qs[i*(
|
362
|
-
#endif //
|
415
|
+
x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
|
416
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
363
417
|
}
|
364
418
|
|
365
|
-
|
419
|
+
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
|
420
|
+
constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
|
366
421
|
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
|
367
422
|
|
368
423
|
#pragma unroll
|
369
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
|
370
|
-
int i = i0 + threadIdx.y *
|
424
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
|
425
|
+
int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
|
371
426
|
|
372
427
|
if (need_check) {
|
373
428
|
i = min(i, i_max);
|
@@ -375,17 +430,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
375
430
|
|
376
431
|
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
|
377
432
|
|
378
|
-
#
|
379
|
-
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1
|
433
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
434
|
+
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
|
380
435
|
#else
|
381
|
-
x_dm[i*(
|
382
|
-
#endif //
|
436
|
+
x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
|
437
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
383
438
|
}
|
384
439
|
}
|
385
440
|
|
386
|
-
template <int mmq_x, int mmq_y
|
441
|
+
template <int mmq_x, int mmq_y>
|
387
442
|
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
|
388
443
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
444
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
445
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
389
446
|
|
390
447
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
|
391
448
|
const int * x_qs = (const int *) x;
|
@@ -394,7 +451,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
|
|
394
451
|
const half2 * y_ds = (const half2 *) y;
|
395
452
|
|
396
453
|
// #pragma unroll
|
397
|
-
for (int k01 = 0; k01 <
|
454
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
|
398
455
|
const int k0 = k00 + k01;
|
399
456
|
|
400
457
|
#pragma unroll
|
@@ -402,7 +459,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
|
|
402
459
|
const int j = j0 + threadIdx.y;
|
403
460
|
|
404
461
|
#pragma unroll
|
405
|
-
for (int i0 = 0; i0 < mmq_y; i0 +=
|
462
|
+
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
406
463
|
const int i = i0 + threadIdx.x;
|
407
464
|
|
408
465
|
const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
|
@@ -415,32 +472,37 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
|
|
415
472
|
u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
|
416
473
|
}
|
417
474
|
|
418
|
-
sum[j0/nwarps*mmq_y/
|
419
|
-
(&x_qs[i*(
|
420
|
-
x_dm[i*(
|
475
|
+
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
|
476
|
+
(&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u,
|
477
|
+
x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
421
478
|
}
|
422
479
|
}
|
423
480
|
}
|
424
481
|
}
|
425
482
|
|
426
|
-
template <int mmq_y,
|
483
|
+
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
|
427
484
|
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
485
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
486
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
428
487
|
|
429
|
-
#
|
488
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
430
489
|
int * x_qs = (int *) x_tile;
|
431
|
-
float * x_df = (float *) (x_qs +
|
490
|
+
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
432
491
|
#else
|
433
492
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
|
434
493
|
int * x_qs = (int *) x_tile;
|
435
494
|
float * x_df = (float *) (x_qs + txs.qs);
|
436
|
-
#endif //
|
495
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
437
496
|
|
438
|
-
|
439
|
-
|
497
|
+
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
|
498
|
+
constexpr int nrows = warp_size / threads_per_row;
|
499
|
+
const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
|
500
|
+
const int kbx = txi / QI5_0;
|
501
|
+
const int kqsx = txi % QI5_0;
|
440
502
|
|
441
503
|
#pragma unroll
|
442
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
443
|
-
int i = i0 + threadIdx.y;
|
504
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
|
505
|
+
int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
|
444
506
|
|
445
507
|
if (need_check) {
|
446
508
|
i = min(i, i_max);
|
@@ -449,7 +511,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
449
511
|
const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
|
450
512
|
|
451
513
|
const int ql = get_int_b2(bxi->qs, kqsx);
|
452
|
-
const int qh = get_int_b2(bxi->qh, 0) >> (4 *
|
514
|
+
const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx);
|
453
515
|
|
454
516
|
int qs0 = (ql >> 0) & 0x0F0F0F0F;
|
455
517
|
qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
|
@@ -465,21 +527,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
465
527
|
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
|
466
528
|
qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
|
467
529
|
|
468
|
-
#
|
530
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
469
531
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
|
470
532
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
|
471
533
|
#else
|
472
|
-
x_qs[i*(2*
|
473
|
-
x_qs[i*(2*
|
474
|
-
#endif //
|
534
|
+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
|
535
|
+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
|
536
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
475
537
|
}
|
476
538
|
|
477
|
-
|
539
|
+
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
|
540
|
+
constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
|
478
541
|
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
|
479
542
|
|
480
543
|
#pragma unroll
|
481
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
|
482
|
-
int i = i0 + threadIdx.y *
|
544
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
|
545
|
+
int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
|
483
546
|
|
484
547
|
if (need_check) {
|
485
548
|
i = min(i, i_max);
|
@@ -487,32 +550,37 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
487
550
|
|
488
551
|
const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
|
489
552
|
|
490
|
-
#
|
491
|
-
x_df[i*MMQ_MMA_TILE_X_K_Q8_0
|
553
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
554
|
+
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
|
492
555
|
#else
|
493
|
-
x_df[i*(
|
494
|
-
#endif //
|
556
|
+
x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
|
557
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
495
558
|
}
|
496
559
|
}
|
497
560
|
|
498
|
-
template <int mmq_y,
|
561
|
+
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
|
499
562
|
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
563
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
564
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
500
565
|
|
501
|
-
#
|
566
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
502
567
|
int * x_qs = (int *) x_tile;
|
503
|
-
half2 * x_dm = (half2 *) (x_qs + 2*
|
568
|
+
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
|
504
569
|
#else
|
505
570
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
|
506
571
|
int * x_qs = (int *) x_tile;
|
507
572
|
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
508
|
-
#endif //
|
573
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
509
574
|
|
510
|
-
|
511
|
-
|
575
|
+
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
|
576
|
+
constexpr int nrows = warp_size / threads_per_row;
|
577
|
+
const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
|
578
|
+
const int kbx = txi / QI5_1;
|
579
|
+
const int kqsx = txi % QI5_1;
|
512
580
|
|
513
581
|
#pragma unroll
|
514
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
515
|
-
int i = i0 + threadIdx.y;
|
582
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
|
583
|
+
int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
|
516
584
|
|
517
585
|
if (need_check) {
|
518
586
|
i = min(i, i_max);
|
@@ -521,7 +589,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
521
589
|
const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
|
522
590
|
|
523
591
|
const int ql = get_int_b4(bxi->qs, kqsx);
|
524
|
-
const int qh = get_int_b4(bxi->qh, 0) >> (4 *
|
592
|
+
const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx);
|
525
593
|
|
526
594
|
int qs0 = (ql >> 0) & 0x0F0F0F0F;
|
527
595
|
qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
|
@@ -535,21 +603,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
535
603
|
qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
|
536
604
|
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
|
537
605
|
|
538
|
-
#
|
606
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
539
607
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
|
540
608
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
|
541
609
|
#else
|
542
|
-
x_qs[i*(2*
|
543
|
-
x_qs[i*(2*
|
544
|
-
#endif //
|
610
|
+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
|
611
|
+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
|
612
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
545
613
|
}
|
546
614
|
|
547
|
-
|
615
|
+
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
|
616
|
+
constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
|
548
617
|
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
|
549
618
|
|
550
619
|
#pragma unroll
|
551
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
|
552
|
-
int i = i0 + threadIdx.y *
|
620
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
|
621
|
+
int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
|
553
622
|
|
554
623
|
if (need_check) {
|
555
624
|
i = min(i, i_max);
|
@@ -557,32 +626,38 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
557
626
|
|
558
627
|
const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
|
559
628
|
|
560
|
-
#
|
561
|
-
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1
|
629
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
630
|
+
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
|
562
631
|
#else
|
563
|
-
x_dm[i*(
|
564
|
-
#endif //
|
632
|
+
x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
|
633
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
565
634
|
}
|
566
635
|
}
|
567
636
|
|
568
|
-
template <int mmq_y,
|
637
|
+
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
|
569
638
|
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
639
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
640
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
570
641
|
|
571
|
-
#
|
642
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
572
643
|
int * x_qs = (int *) x_tile;
|
573
|
-
float * x_df = (float *) (x_tile + 2*
|
644
|
+
float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
|
574
645
|
#else
|
575
646
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
|
576
647
|
int * x_qs = (int *) x_tile;
|
577
648
|
float * x_df = (float *) (x_qs + txs.qs);
|
578
|
-
#endif //
|
649
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
579
650
|
|
580
|
-
|
581
|
-
|
651
|
+
// MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp
|
652
|
+
constexpr int threads_per_row = 32;
|
653
|
+
constexpr int nrows = warp_size / threads_per_row;
|
654
|
+
const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
|
655
|
+
const int kbx = txi / QI8_0;
|
656
|
+
const int kqsx = txi % QI8_0;
|
582
657
|
|
583
658
|
#pragma unroll
|
584
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
585
|
-
int i = i0 + threadIdx.y;
|
659
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
|
660
|
+
int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
|
586
661
|
|
587
662
|
if (need_check) {
|
588
663
|
i = min(i, i_max);
|
@@ -590,21 +665,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
590
665
|
|
591
666
|
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
|
592
667
|
|
593
|
-
#
|
594
|
-
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0
|
595
|
-
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 +
|
668
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
669
|
+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
|
670
|
+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
|
596
671
|
#else
|
597
|
-
x_qs[i*(2*
|
598
|
-
x_qs[i*(2*
|
599
|
-
#endif //
|
672
|
+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
|
673
|
+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
|
674
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
600
675
|
}
|
601
676
|
|
602
|
-
|
677
|
+
constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
|
678
|
+
constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
|
603
679
|
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
|
604
680
|
|
605
681
|
#pragma unroll
|
606
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
|
607
|
-
int i = i0 + threadIdx.y *
|
682
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
|
683
|
+
int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
|
608
684
|
|
609
685
|
if (need_check) {
|
610
686
|
i = min(i, i_max);
|
@@ -612,17 +688,84 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
612
688
|
|
613
689
|
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
|
614
690
|
|
615
|
-
#
|
616
|
-
x_df[i*MMQ_MMA_TILE_X_K_Q8_0
|
691
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
692
|
+
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
|
617
693
|
#else
|
618
|
-
x_df[i*(2*
|
619
|
-
#endif //
|
694
|
+
x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
|
695
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
620
696
|
}
|
621
697
|
}
|
622
698
|
|
623
|
-
template <int
|
699
|
+
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4(
|
700
|
+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
701
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
702
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
703
|
+
|
704
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
705
|
+
int * x_qs = (int *) x_tile;
|
706
|
+
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
707
|
+
#else
|
708
|
+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
|
709
|
+
int * x_qs = (int *) x_tile;
|
710
|
+
float * x_df = (float *) (x_qs + txs.qs);
|
711
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
712
|
+
|
713
|
+
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
|
714
|
+
constexpr int nrows = warp_size / threads_per_row;
|
715
|
+
const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
|
716
|
+
const int kbx = txi / QI_MXFP4;
|
717
|
+
const int kqsx = txi % QI_MXFP4;
|
718
|
+
|
719
|
+
#pragma unroll
|
720
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
|
721
|
+
int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
|
722
|
+
|
723
|
+
if (need_check) {
|
724
|
+
i = min(i, i_max);
|
725
|
+
}
|
726
|
+
|
727
|
+
const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx;
|
728
|
+
|
729
|
+
const int aux_q4 = get_int_b1(bxi->qs, kqsx);
|
730
|
+
const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
|
731
|
+
const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
|
732
|
+
|
733
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
734
|
+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x;
|
735
|
+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
|
736
|
+
#else
|
737
|
+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
|
738
|
+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
|
739
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
740
|
+
}
|
741
|
+
|
742
|
+
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
|
743
|
+
constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
|
744
|
+
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
|
745
|
+
|
746
|
+
#pragma unroll
|
747
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
|
748
|
+
int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
|
749
|
+
|
750
|
+
if (need_check) {
|
751
|
+
i = min(i, i_max);
|
752
|
+
}
|
753
|
+
|
754
|
+
const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
|
755
|
+
|
756
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
757
|
+
x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
|
758
|
+
#else
|
759
|
+
x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
|
760
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
761
|
+
}
|
762
|
+
}
|
763
|
+
|
764
|
+
template <int mmq_x, int mmq_y>
|
624
765
|
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
|
625
766
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
767
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
768
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
626
769
|
|
627
770
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
|
628
771
|
const int * x_qs = (const int *) x;
|
@@ -631,7 +774,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
|
|
631
774
|
const float * y_df = (const float *) y;
|
632
775
|
|
633
776
|
// #pragma unroll
|
634
|
-
for (int k01 = 0; k01 <
|
777
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
|
635
778
|
const int k0 = k00 + k01;
|
636
779
|
|
637
780
|
#pragma unroll
|
@@ -639,21 +782,76 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
|
|
639
782
|
const int j = j0 + threadIdx.y;
|
640
783
|
|
641
784
|
#pragma unroll
|
642
|
-
for (int i0 = 0; i0 < mmq_y; i0 +=
|
785
|
+
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
643
786
|
const int i = i0 + threadIdx.x;
|
644
787
|
|
645
|
-
sum[j0/nwarps*mmq_y/
|
646
|
-
(&x_qs[i*(2*
|
647
|
-
x_df[i*(2*
|
788
|
+
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
|
789
|
+
(&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K],
|
790
|
+
x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]);
|
648
791
|
}
|
649
792
|
}
|
650
793
|
}
|
651
794
|
}
|
652
795
|
|
653
|
-
template <int mmq_x, int mmq_y,
|
796
|
+
template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
|
654
797
|
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
655
798
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
799
|
+
#if defined(AMD_MFMA_AVAILABLE)
|
800
|
+
typedef tile<16, 8, int> tile_A;
|
801
|
+
typedef tile<16, 8, int> tile_B;
|
802
|
+
typedef tile<16, 16, int> tile_C;
|
803
|
+
|
804
|
+
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
805
|
+
constexpr int rows_per_warp = granularity;
|
806
|
+
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
807
|
+
|
808
|
+
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
809
|
+
|
810
|
+
const int * x_qs = (const int *) x;
|
811
|
+
const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
|
812
|
+
const int * y_qs = (const int *) y + 4;
|
813
|
+
const float * y_df = (const float *) y;
|
814
|
+
const half2 * y_ds = (const half2 *) y;
|
656
815
|
|
816
|
+
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
817
|
+
|
818
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
|
819
|
+
const int k0 = k00 + k01;
|
820
|
+
|
821
|
+
tile_A A[ntx];
|
822
|
+
#pragma unroll
|
823
|
+
for (int n = 0; n < ntx; ++n) {
|
824
|
+
load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
|
825
|
+
}
|
826
|
+
|
827
|
+
#pragma unroll
|
828
|
+
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
829
|
+
tile_B B;
|
830
|
+
load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
831
|
+
|
832
|
+
float dB;
|
833
|
+
const int j = j0 + tile_C::get_j(0);
|
834
|
+
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
|
835
|
+
dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
|
836
|
+
} else {
|
837
|
+
dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
838
|
+
}
|
839
|
+
|
840
|
+
#pragma unroll
|
841
|
+
for (int n = 0; n < ntx; ++n) {
|
842
|
+
tile_C C;
|
843
|
+
mma(C, A[n], B);
|
844
|
+
|
845
|
+
#pragma unroll
|
846
|
+
for (int l = 0; l < tile_C::ne; ++l) {
|
847
|
+
const int i = i0 + n*tile_A::I + tile_C::get_i(l);
|
848
|
+
const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
|
849
|
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB;
|
850
|
+
}
|
851
|
+
}
|
852
|
+
}
|
853
|
+
}
|
854
|
+
#else
|
657
855
|
typedef tile<16, 8, int> tile_A;
|
658
856
|
typedef tile< 8, 8, int> tile_B;
|
659
857
|
typedef tile<16, 8, int> tile_C;
|
@@ -662,23 +860,23 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
|
662
860
|
constexpr int rows_per_warp = 2 * granularity;
|
663
861
|
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
664
862
|
|
665
|
-
y += (threadIdx.y % ntx) * (
|
863
|
+
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
666
864
|
|
667
865
|
const int * x_qs = (const int *) x;
|
668
|
-
const float * x_df = (const float *) x_qs + 2*
|
866
|
+
const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
|
669
867
|
const int * y_qs = (const int *) y + 4;
|
670
868
|
const float * y_df = (const float *) y;
|
671
869
|
const half2 * y_ds = (const half2 *) y;
|
672
870
|
|
673
|
-
tile_A A[ntx][
|
674
|
-
float dA[ntx][tile_C::ne/2][
|
871
|
+
tile_A A[ntx][MMQ_TILE_NE_K/QI8_0];
|
872
|
+
float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0];
|
675
873
|
|
676
874
|
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
|
677
875
|
|
678
876
|
#pragma unroll
|
679
877
|
for (int n = 0; n < ntx; ++n) {
|
680
878
|
#pragma unroll
|
681
|
-
for (int k01 = 0; k01 <
|
879
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
|
682
880
|
const int k0 = k00 + k01;
|
683
881
|
|
684
882
|
load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
|
@@ -689,7 +887,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
|
689
887
|
const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
|
690
888
|
|
691
889
|
#pragma unroll
|
692
|
-
for (int k01 = 0; k01 <
|
890
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
|
693
891
|
const int k0 = k00 + k01;
|
694
892
|
|
695
893
|
dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
|
@@ -700,7 +898,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
|
700
898
|
#pragma unroll
|
701
899
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
702
900
|
#pragma unroll
|
703
|
-
for (int k01 = 0; k01 <
|
901
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
|
704
902
|
tile_B B;
|
705
903
|
float dB[tile_C::ne/2];
|
706
904
|
|
@@ -729,11 +927,14 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
|
729
927
|
}
|
730
928
|
}
|
731
929
|
}
|
930
|
+
#endif // defined(AMD_MFMA_AVAILABLE)
|
732
931
|
}
|
733
932
|
|
734
|
-
template <int mmq_x, int mmq_y
|
933
|
+
template <int mmq_x, int mmq_y>
|
735
934
|
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
|
736
935
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
936
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
937
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
737
938
|
|
738
939
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
|
739
940
|
const int * x_qs = (const int *) x;
|
@@ -742,7 +943,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
|
|
742
943
|
const half2 * y_ds = (const half2 *) y;
|
743
944
|
|
744
945
|
// #pragma unroll
|
745
|
-
for (int k01 = 0; k01 <
|
946
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
|
746
947
|
const int k0 = k00 + k01;
|
747
948
|
|
748
949
|
#pragma unroll
|
@@ -750,45 +951,95 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
|
|
750
951
|
const int j = j0 + threadIdx.y;
|
751
952
|
|
752
953
|
#pragma unroll
|
753
|
-
for (int i0 = 0; i0 < mmq_y; i0 +=
|
954
|
+
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
754
955
|
const int i = i0 + threadIdx.x;
|
755
956
|
|
756
|
-
sum[j0/nwarps*mmq_y/
|
757
|
-
(&x_qs[i*(2*
|
758
|
-
x_dm[i*(
|
957
|
+
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
|
958
|
+
(&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
|
959
|
+
x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
759
960
|
}
|
760
961
|
}
|
761
962
|
}
|
762
963
|
}
|
763
964
|
|
764
|
-
template <int mmq_x, int mmq_y
|
965
|
+
template <int mmq_x, int mmq_y>
|
765
966
|
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
766
967
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
968
|
+
#if defined(AMD_MFMA_AVAILABLE)
|
969
|
+
typedef tile<16, 8, int> tile_A;
|
970
|
+
typedef tile<16, 8, int> tile_B;
|
971
|
+
typedef tile<16, 16, int> tile_C;
|
767
972
|
|
768
|
-
|
769
|
-
|
770
|
-
|
973
|
+
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
974
|
+
constexpr int rows_per_warp = granularity;
|
975
|
+
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
976
|
+
|
977
|
+
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
978
|
+
|
979
|
+
const int * x_qs = (const int *) x;
|
980
|
+
const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
|
981
|
+
const int * y_qs = (const int *) y + 4;
|
982
|
+
const half2 * y_dm = (const half2 *) y;
|
983
|
+
|
984
|
+
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
985
|
+
|
986
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
|
987
|
+
const int k0 = k00 + k01;
|
988
|
+
|
989
|
+
tile_A A[ntx];
|
990
|
+
#pragma unroll
|
991
|
+
for (int n = 0; n < ntx; ++n) {
|
992
|
+
load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
|
993
|
+
}
|
994
|
+
|
995
|
+
#pragma unroll
|
996
|
+
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
997
|
+
tile_B B;
|
998
|
+
load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
999
|
+
|
1000
|
+
const int j = j0 + tile_C::get_j(0);
|
1001
|
+
const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
1002
|
+
|
1003
|
+
#pragma unroll
|
1004
|
+
for (int n = 0; n < ntx; ++n) {
|
1005
|
+
tile_C C;
|
1006
|
+
mma(C, A[n], B);
|
1007
|
+
|
1008
|
+
#pragma unroll
|
1009
|
+
for (int l = 0; l < tile_C::ne; ++l) {
|
1010
|
+
const int i = i0 + n*tile_A::I + tile_C::get_i(l);
|
1011
|
+
float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
|
1012
|
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l];
|
1013
|
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y;
|
1014
|
+
}
|
1015
|
+
}
|
1016
|
+
}
|
1017
|
+
}
|
1018
|
+
#else
|
1019
|
+
typedef tile<16, 8, int> tile_A;
|
1020
|
+
typedef tile< 8, 8, int> tile_B;
|
1021
|
+
typedef tile<16, 8, int> tile_C;
|
771
1022
|
|
772
1023
|
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
773
1024
|
constexpr int rows_per_warp = 2 * granularity;
|
774
1025
|
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
775
1026
|
|
776
|
-
y += (threadIdx.y % ntx) * (
|
1027
|
+
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
777
1028
|
|
778
1029
|
const int * x_qs = (const int *) x;
|
779
|
-
const half2 * x_dm = (const half2 *) x_qs + 2*
|
1030
|
+
const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
|
780
1031
|
const int * y_qs = (const int *) y + 4;
|
781
1032
|
const half2 * y_dm = (const half2 *) y;
|
782
1033
|
|
783
|
-
tile_A A[ntx][
|
784
|
-
float2 dmA[ntx][tile_C::ne/2][
|
1034
|
+
tile_A A[ntx][MMQ_TILE_NE_K/QI8_1];
|
1035
|
+
float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1];
|
785
1036
|
|
786
1037
|
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
|
787
1038
|
|
788
1039
|
#pragma unroll
|
789
1040
|
for (int n = 0; n < ntx; ++n) {
|
790
1041
|
#pragma unroll
|
791
|
-
for (int k01 = 0; k01 <
|
1042
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
|
792
1043
|
const int k0 = k00 + k01;
|
793
1044
|
|
794
1045
|
load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
|
@@ -799,7 +1050,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
|
799
1050
|
const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
|
800
1051
|
|
801
1052
|
#pragma unroll
|
802
|
-
for (int k01 = 0; k01 <
|
1053
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
|
803
1054
|
const int k0 = k00 + k01;
|
804
1055
|
|
805
1056
|
dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
|
@@ -810,7 +1061,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
|
810
1061
|
#pragma unroll
|
811
1062
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
812
1063
|
#pragma unroll
|
813
|
-
for (int k01 = 0; k01 <
|
1064
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
|
814
1065
|
tile_B B;
|
815
1066
|
float2 dsB[tile_C::ne/2];
|
816
1067
|
|
@@ -836,11 +1087,15 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
|
836
1087
|
}
|
837
1088
|
}
|
838
1089
|
}
|
1090
|
+
#endif // defined(AMD_MFMA_AVAILABLE)
|
839
1091
|
}
|
840
1092
|
|
841
|
-
|
1093
|
+
// Used for Q3_K, IQ2_S, and IQ2_XS
|
1094
|
+
template <int mmq_x, int mmq_y>
|
842
1095
|
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
|
843
1096
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
1097
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
1098
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
844
1099
|
|
845
1100
|
constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
|
846
1101
|
const int * x_qs = (const int *) x;
|
@@ -849,7 +1104,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
|
|
849
1104
|
const float * y_df = (const float *) y;
|
850
1105
|
|
851
1106
|
// #pragma unroll
|
852
|
-
for (int k01 = 0; k01 <
|
1107
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
|
853
1108
|
const int k0 = k00 + k01;
|
854
1109
|
|
855
1110
|
#pragma unroll
|
@@ -857,23 +1112,73 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
|
|
857
1112
|
const int j = j0 + threadIdx.y;
|
858
1113
|
|
859
1114
|
#pragma unroll
|
860
|
-
for (int i0 = 0; i0 < mmq_y; i0 +=
|
1115
|
+
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
861
1116
|
const int i = i0 + threadIdx.x;
|
862
1117
|
|
863
|
-
sum[j0/nwarps*mmq_y/
|
864
|
-
&x_qs[i*(2*
|
1118
|
+
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
|
1119
|
+
&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0],
|
865
1120
|
&y_qs[j*MMQ_TILE_Y_K + k01],
|
866
|
-
&x_df[i*(2*
|
1121
|
+
&x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
|
867
1122
|
y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
868
1123
|
}
|
869
1124
|
}
|
870
1125
|
}
|
871
1126
|
}
|
872
1127
|
|
873
|
-
|
1128
|
+
// Used for Q3_K, IQ2_S, and IQ2_XS:
|
1129
|
+
template <int mmq_x, int mmq_y>
|
874
1130
|
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
875
1131
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
876
|
-
#
|
1132
|
+
#if defined(AMD_MFMA_AVAILABLE)
|
1133
|
+
typedef tile<16, 8, int> tile_A;
|
1134
|
+
typedef tile<16, 8, int> tile_B;
|
1135
|
+
typedef tile<16, 16, int> tile_C;
|
1136
|
+
typedef tile<64, 2, int> tile_load;
|
1137
|
+
|
1138
|
+
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
1139
|
+
constexpr int rows_per_warp = granularity;
|
1140
|
+
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
1141
|
+
|
1142
|
+
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
1143
|
+
|
1144
|
+
const int * x_qs = (const int *) x;
|
1145
|
+
const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
|
1146
|
+
const int * y_qs = (const int *) y + 4;
|
1147
|
+
const float * y_df = (const float *) y;
|
1148
|
+
|
1149
|
+
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
1150
|
+
|
1151
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
|
1152
|
+
const int k0 = k00 + k01;
|
1153
|
+
|
1154
|
+
tile_A A[ntx];
|
1155
|
+
#pragma unroll
|
1156
|
+
for (int n = 0; n < ntx; ++n) {
|
1157
|
+
load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
|
1158
|
+
}
|
1159
|
+
|
1160
|
+
#pragma unroll
|
1161
|
+
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
1162
|
+
tile_B B[1];
|
1163
|
+
load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
1164
|
+
|
1165
|
+
const int j = j0 + tile_C::get_j(0);
|
1166
|
+
const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
|
1167
|
+
|
1168
|
+
#pragma unroll
|
1169
|
+
for (int n = 0; n < ntx; ++n) {
|
1170
|
+
tile_C C;
|
1171
|
+
mma(C, A[n], B[0]);
|
1172
|
+
|
1173
|
+
#pragma unroll
|
1174
|
+
for (int l = 0; l < tile_C::ne; ++l) {
|
1175
|
+
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
|
1176
|
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
|
1177
|
+
}
|
1178
|
+
}
|
1179
|
+
}
|
1180
|
+
}
|
1181
|
+
#elif defined(TURING_MMA_AVAILABLE)
|
877
1182
|
|
878
1183
|
typedef tile<16, 4, int> tile_A;
|
879
1184
|
typedef tile<16, 8, int> tile_A_8;
|
@@ -884,10 +1189,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
|
884
1189
|
constexpr int rows_per_warp = 2 * granularity;
|
885
1190
|
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
886
1191
|
|
887
|
-
y += (threadIdx.y % ntx) * (
|
1192
|
+
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
888
1193
|
|
889
1194
|
const int * x_qs = (const int *) x;
|
890
|
-
const float * x_df = (const float *) x_qs +
|
1195
|
+
const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
|
891
1196
|
const int * y_qs = (const int *) y + 4;
|
892
1197
|
const float * y_df = (const float *) y;
|
893
1198
|
|
@@ -899,7 +1204,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
|
899
1204
|
#pragma unroll
|
900
1205
|
for (int n = 0; n < ntx; ++n) {
|
901
1206
|
#pragma unroll
|
902
|
-
for (int k01 = 0; k01 <
|
1207
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
|
903
1208
|
const int k0 = k00 + k01;
|
904
1209
|
|
905
1210
|
load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
|
@@ -910,7 +1215,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
|
910
1215
|
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
|
911
1216
|
|
912
1217
|
#pragma unroll
|
913
|
-
for (int k01 = 0; k01 <
|
1218
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
|
914
1219
|
const int k0 = k00 + k01;
|
915
1220
|
|
916
1221
|
dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4];
|
@@ -921,7 +1226,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
|
921
1226
|
#pragma unroll
|
922
1227
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
923
1228
|
#pragma unroll
|
924
|
-
for (int k01 = 0; k01 <
|
1229
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
|
925
1230
|
tile_B B[2];
|
926
1231
|
float dB[tile_C::ne/2];
|
927
1232
|
|
@@ -950,28 +1255,31 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
|
950
1255
|
}
|
951
1256
|
}
|
952
1257
|
#else
|
953
|
-
|
1258
|
+
GGML_UNUSED_VARS(x, y, sum, k00);
|
954
1259
|
NO_DEVICE_CODE;
|
955
|
-
#endif //
|
1260
|
+
#endif // AMD_MFMA_AVAILABLE
|
956
1261
|
}
|
957
1262
|
|
958
|
-
template <int mmq_y,
|
1263
|
+
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
|
959
1264
|
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
1265
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
960
1266
|
|
961
|
-
#
|
1267
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
962
1268
|
int * x_qs = (int *) x_tile;
|
963
|
-
half2 * x_dm = (half2 *) (x_qs + 2*
|
1269
|
+
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
|
964
1270
|
#else
|
965
1271
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
|
966
1272
|
int * x_qs = (int *) x_tile;
|
967
1273
|
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
968
|
-
#endif //
|
1274
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
969
1275
|
|
970
|
-
|
1276
|
+
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
|
1277
|
+
constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
|
1278
|
+
const int kqsx = threadIdx.x % threads_per_row;
|
971
1279
|
|
972
1280
|
#pragma unroll
|
973
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps
|
974
|
-
int i = i0 + threadIdx.y*
|
1281
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
|
1282
|
+
int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
|
975
1283
|
|
976
1284
|
if (need_check) {
|
977
1285
|
i = min(i, i_max);
|
@@ -987,11 +1295,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
987
1295
|
|
988
1296
|
const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
|
989
1297
|
|
990
|
-
#
|
1298
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
991
1299
|
x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
|
992
1300
|
#else
|
993
|
-
x_qs[i*(2*
|
994
|
-
#endif //
|
1301
|
+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
|
1302
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
995
1303
|
}
|
996
1304
|
|
997
1305
|
const int sc_m = bxi->scales[kqsx];
|
@@ -1002,17 +1310,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
1002
1310
|
const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
|
1003
1311
|
#endif // FAST_FP16_AVAILABLE
|
1004
1312
|
|
1005
|
-
#
|
1313
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
1006
1314
|
x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
|
1007
1315
|
#else
|
1008
|
-
x_dm[i*(
|
1009
|
-
#endif //
|
1316
|
+
x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik;
|
1317
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
1010
1318
|
}
|
1011
1319
|
}
|
1012
1320
|
|
1013
|
-
template <int mmq_x, int mmq_y
|
1321
|
+
template <int mmq_x, int mmq_y>
|
1014
1322
|
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
|
1015
1323
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
1324
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
1325
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
1016
1326
|
|
1017
1327
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
|
1018
1328
|
const int * x_qs = (const int *) x;
|
@@ -1029,7 +1339,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
|
|
1029
1339
|
}
|
1030
1340
|
|
1031
1341
|
#pragma unroll
|
1032
|
-
for (int k01 = 0; k01 <
|
1342
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
|
1033
1343
|
const int k0 = k00 + k01;
|
1034
1344
|
|
1035
1345
|
#pragma unroll
|
@@ -1037,13 +1347,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
|
|
1037
1347
|
const int j = j0 + threadIdx.y;
|
1038
1348
|
|
1039
1349
|
#pragma unroll
|
1040
|
-
for (int i0 = 0; i0 < mmq_y; i0 +=
|
1350
|
+
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
1041
1351
|
const int i = i0 + threadIdx.x;
|
1042
1352
|
|
1043
1353
|
constexpr int ns = 2;
|
1044
|
-
sum[j0/nwarps*mmq_y/
|
1045
|
-
&x_qs[i*(2*
|
1046
|
-
&x_dm[i*(
|
1354
|
+
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
|
1355
|
+
&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
|
1356
|
+
&x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
|
1047
1357
|
&y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
|
1048
1358
|
}
|
1049
1359
|
}
|
@@ -1052,7 +1362,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
|
|
1052
1362
|
// Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop.
|
1053
1363
|
// As a workaround 2 separate loops are used instead.
|
1054
1364
|
#pragma unroll
|
1055
|
-
for (int k01 =
|
1365
|
+
for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
|
1056
1366
|
const int k0 = k00 + k01;
|
1057
1367
|
|
1058
1368
|
#pragma unroll
|
@@ -1060,23 +1370,89 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
|
|
1060
1370
|
const int j = j0 + threadIdx.y;
|
1061
1371
|
|
1062
1372
|
#pragma unroll
|
1063
|
-
for (int i0 = 0; i0 < mmq_y; i0 +=
|
1373
|
+
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
1064
1374
|
const int i = i0 + threadIdx.x;
|
1065
1375
|
|
1066
1376
|
constexpr int ns = 1;
|
1067
|
-
sum[j0/nwarps*mmq_y/
|
1068
|
-
&x_qs[i*(2*
|
1069
|
-
&x_dm[i*(
|
1377
|
+
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
|
1378
|
+
&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
|
1379
|
+
&x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
|
1070
1380
|
&y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
|
1071
1381
|
}
|
1072
1382
|
}
|
1073
1383
|
}
|
1074
1384
|
}
|
1075
1385
|
|
1076
|
-
template <int mmq_x, int mmq_y
|
1386
|
+
template <int mmq_x, int mmq_y>
|
1077
1387
|
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
1078
1388
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
1079
|
-
#
|
1389
|
+
#if defined(AMD_MFMA_AVAILABLE)
|
1390
|
+
typedef tile<16, 8, int> tile_A;
|
1391
|
+
typedef tile<16, 8, int> tile_B;
|
1392
|
+
typedef tile<16, 16, int> tile_C;
|
1393
|
+
typedef tile<64, 2, int> tile_load;
|
1394
|
+
|
1395
|
+
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
1396
|
+
constexpr int rows_per_warp = granularity;
|
1397
|
+
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
1398
|
+
|
1399
|
+
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
1400
|
+
|
1401
|
+
const int * x_qs = (const int *) x;
|
1402
|
+
const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
|
1403
|
+
const int * y_qs = (const int *) y + 4;
|
1404
|
+
const half2 * y_ds = (const half2 *) y;
|
1405
|
+
|
1406
|
+
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
1407
|
+
|
1408
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
|
1409
|
+
const int k0 = k00 + k01;
|
1410
|
+
|
1411
|
+
tile_A A[ntx];
|
1412
|
+
#pragma unroll
|
1413
|
+
for (int n = 0; n < ntx; ++n) {
|
1414
|
+
load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
|
1415
|
+
}
|
1416
|
+
|
1417
|
+
#pragma unroll
|
1418
|
+
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
1419
|
+
tile_B B[1];
|
1420
|
+
load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
1421
|
+
|
1422
|
+
const int j = j0 + tile_C::get_j(0);
|
1423
|
+
+const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2;
+const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
+: (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
+: __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
+
+tile_C Cm;
+if (k01 >= MMQ_TILE_NE_K * 3/4) {
+tile_A A1;
+A1.x[0] = 0x01010101;
+A1.x[1] = 0x01010101;
+mma(Cm, A1, B[0]);
+}
+
+#pragma unroll
+for (int n = 0; n < ntx; ++n) {
+tile_C Cd;
+mma(Cd, A[n], B[0]);
+
+#pragma unroll
+for (int l = 0; l < tile_C::ne; ++l) {
+const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
+float tmp = Cd.x[l]*dm.x;
+if (k01 >= MMQ_TILE_NE_K * 3/4) {
+tmp -= Cm.x[l]*dm.y;
+}
+sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
+sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
+}
+}
+}
+}
+#elif defined(TURING_MMA_AVAILABLE)

 typedef tile<16, 4, int> tile_A;
 typedef tile<16, 8, int> tile_A_8;
@@ -1087,10 +1463,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 constexpr int rows_per_warp = 2 * granularity;
 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-y += (threadIdx.y % ntx) * (
+y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

 const int * x_qs = (const int *) x;
-const half2 * x_dm = (const half2 *) x_qs +
+const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
 const int * y_qs = (const int *) y + 4;
 const half2 * y_ds = (const half2 *) y;

@@ -1103,7 +1479,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 #pragma unroll
 for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-for (int k01 = 0; k01 <
+for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
 const int k0 = k00 + k01;

 load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
@@ -1117,7 +1493,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);

 #pragma unroll
-for (int k01 = 0; k01 <
+for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) {
 const int k0 = k00 + k01;

 const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);
@@ -1140,7 +1516,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 }

 #pragma unroll
-for (int k01 = 0; k01 <
+for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
 tile_B B[2];

 // Here load_generic is faster than load_ldmatrix.
@@ -1148,7 +1524,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);

 tile_C Cm[2];
-if (k01 >=
+if (k01 >= MMQ_TILE_NE_K * 3/4) {
 tile_A A1;
 A1.x[0] = 0x01010101;
 A1.x[1] = 0x01010101;
@@ -1166,16 +1542,16 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 #pragma unroll
 for (int l = 0; l < tile_C::ne; ++l) {
 float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
-if (k01 >=
+if (k01 >= MMQ_TILE_NE_K * 3/4) {
 tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
 }
-sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 <
+sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? dB[l%2].x : dB[l%2].y);
 }
 }
 }

 #pragma unroll
-for (int k01 = 0; k01 <
+for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) {
 float2 sB[tile_C::ne/2];

 #pragma unroll
@@ -1196,29 +1572,33 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 }
 }
 #else
-
+GGML_UNUSED_VARS(x, y, sum, k00);
 NO_DEVICE_CODE;
-#endif //
+#endif // AMD_MFMA_AVAILABLE
 }

-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+constexpr int nwarps = mmq_get_nwarps_device();
+constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 int * x_qs = (int *) x_tile;
-float * x_df = (float *) (x_qs +
+float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
 int * x_qs = (int *) x_tile;
 float * x_df = (float *) (x_qs + txs.qs);
 int * x_sc = (int *) (x_df + txs.dm);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)

-
+constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K);
+constexpr int nrows = warp_size / threads_per_row;
+const int kqsx = threadIdx.x % threads_per_row;

 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps
-int i = i0 + threadIdx.y
+for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;

 if (need_check) {
 i = min(i, i_max);
@@ -1238,17 +1618,18 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

 const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
 #else
-x_qs[i*(2*
-#endif //
+x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }
 }

+constexpr int rows_per_warp = warp_size / 4;
 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps*
-int i = i0 + threadIdx.y*
+for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4;

 if (need_check) {
 i = min(i, i_max);
@@ -1256,7 +1637,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

 const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;

-const int ksc = threadIdx.x %
+const int ksc = threadIdx.x % 4;

 const int ksc_low = ksc % (QI3_K/8);
 const int shift_low = 4 * (ksc / (QI3_K/8));
@@ -1268,23 +1649,23 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

 const int sc = __vsubss4(sc_low | sc_high, 0x20202020);

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 const int8_t * sc8 = (const int8_t *) &sc;
 const float d = bxi->d;

 #pragma unroll
 for (int l = 0; l < int(sizeof(int)); ++l) {
-x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*
+x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l];
 }
 #else
-x_sc[i*(
-#endif //
+x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }

-#
+#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps*
-int i = (i0 + threadIdx.y*
+for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
+int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;

 if (need_check) {
 i = min(i, i_max);
@@ -1294,12 +1675,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

 x_df[i] = bxi->d;
 }
-#endif //
+#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
 }

-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+constexpr int nwarps = mmq_get_nwarps_device();
+constexpr int warp_size = ggml_cuda_get_physical_warp_size();

 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
 const int * x_qs = (const int *) x;
@@ -1309,7 +1692,7 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
 const float * y_df = (const float *) y;

 // #pragma unroll
-for (int k01 = 0; k01 <
+for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
 const int k0 = k00 + k01;

 #pragma unroll
@@ -1317,13 +1700,13 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
 const int j = j0 + threadIdx.y;

 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 +=
+for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
 const int i = i0 + threadIdx.x;

-const int8_t * scales = ((const int8_t *) (x_sc + i*(
+const int8_t * scales = ((const int8_t *) (x_sc + i*(MMQ_TILE_NE_K/8) + i/8)) + k0/4;

-sum[j0/nwarps*mmq_y/
-&x_qs[i*(2*
+sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q3_K_q8_1_impl_mmq(
+&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
 x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
 }
 }
@@ -1340,72 +1723,85 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co
 ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits
 }

-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+constexpr int nwarps = mmq_get_nwarps_device();
+constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 int * x_qs = (int *) x_tile;
-half2 * x_dm = (half2 *) (x_qs + 2*
+half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
 int * x_qs = (int *) x_tile;
 half2 * x_dm = (half2 *) (x_qs + txs.qs);
 int * x_sc = (int *) (x_dm + txs.dm);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+
+constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
+constexpr int nrows = warp_size / threads_per_row;
+const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;

 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-int i = i0 + threadIdx.y;
+for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

 if (need_check) {
 i = min(i, i_max);
 }

 const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
-const int qs0 = get_int_b4(bxi->qs,
+const int qs0 = get_int_b4(bxi->qs, txi);

-#
-x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(
-x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
+x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
-x_qs[i*(
-#endif //
+x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }

-#
-
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+constexpr int rows_per_warp = warp_size / 2;
 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps*
-
-
-
-
-
+for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+#if defined(AMD_MFMA_AVAILABLE)
+// Need if on AMD instead of % because warp_size == 64
+// This causes double work and throughput loss (MI300X)
+// H100 loses about 100 t/s with 'if' condition over '%'
+int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
+if (i < mmq_y) {
+#else
+int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
+{
+#endif // defined(AMD_MFMA_AVAILABLE)
+if (need_check) {
+i = min(i, i_max);
+}

-
+const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;

-
-
+const int * scales = (const int *) bxi->scales;
+const int ksc = threadIdx.x % 2;

-
-
+const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
+const int m32 = unpack_scales_q45_K(scales, ksc + 2);

-
-
+const uint8_t * sc8 = (const uint8_t *) &sc32;
+const uint8_t * m8 = (const uint8_t *) &m32;

-
+const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);

-#pragma unroll
-
-
+#pragma unroll
+for (int l = 0; l < sizeof(int); ++l) {
+x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+}
 }
 }
-
 #else
-
 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps*
-int i = (i0 + threadIdx.y*
+for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
+int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;

 if (need_check) {
 i = min(i, i_max);
@@ -1415,30 +1811,32 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

 x_dm[i] = bxi->dm;
 }
-
+constexpr int rows_per_warp = warp_size / 4;
 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps
-int i = (i0 + threadIdx.y
+for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;

 if (need_check) {
 i = min(i, i_max);
 }

-const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (
+const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8);

 const int * scales = (const int *) bxi->scales;

-const int ksc = threadIdx.x % (
+const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
 const int scales8 = unpack_scales_q45_K(scales, ksc);

-x_sc[i*(
+x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
 }
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }

-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+constexpr int nwarps = mmq_get_nwarps_device();
+constexpr int warp_size = ggml_cuda_get_physical_warp_size();

 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
 const int * x_qs = (const int *) x;
@@ -1448,7 +1846,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
 const half2 * y_ds = (const half2 *) y;

 // #pragma unroll
-for (int k01 = 0; k01 <
+for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
 const int k0 = k00 + k01;

 #pragma unroll
@@ -1456,97 +1854,110 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
 const int j = j0 + threadIdx.y;

 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 +=
+for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
 const int i = i0 + threadIdx.x;

-const uint8_t * sc = (const uint8_t *) &x_sc[i * (
+const uint8_t * sc = (const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/32] + 2*(k01/16);

-sum[j0/nwarps*mmq_y/
-&x_qs[i*(
+sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_K_q8_1_impl_mmq(
+&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
 x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
 }
 }
 }
 }

-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+constexpr int nwarps = mmq_get_nwarps_device();
+constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 int * x_qs = (int *) x_tile;
-half2 * x_dm = (half2 *) (x_qs +
+half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
 #else
 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
 int * x_qs = (int *) x_tile;
 half2 * x_dm = (half2 *) (x_qs + txs.qs);
 int * x_sc = (int *) (x_dm + txs.dm);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+
+constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K);
+constexpr int nrows = warp_size / threads_per_row;
+const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;

 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-int i = i0 + threadIdx.y;
+for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

 if (need_check) {
 i = min(i, i_max);
 }

 const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
-const int ky = QR5_K*
+const int ky = QR5_K*txi;

-const int ql = get_int_b4(bxi->qs,
+const int ql = get_int_b4(bxi->qs, txi);
 const int ql0 = (ql >> 0) & 0x0F0F0F0F;
 const int ql1 = (ql >> 4) & 0x0F0F0F0F;

-const int qh = get_int_b4(bxi->qh,
-const int qh0 = ((qh >> (2 * (
-const int qh1 = ((qh >> (2 * (
+const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4));
+const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010;
+const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010;

-const int kq0 = ky - ky % (QI5_K/2) +
-const int kq1 = ky - ky % (QI5_K/2) +
+const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
+const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
 #else
-x_qs[i*(2*
-x_qs[i*(2*
-#endif //
+x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
+x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }

-#
-
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+constexpr int rows_per_warp = warp_size / 2;
 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps*
-
-
-
-
-
+for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+#if defined(AMD_MFMA_AVAILABLE)
+// Need if on AMD instead of % because warp_size == 64
+// This causes double work and throughput loss (MI300X)
+// H100 loses about 100 t/s with 'if' condition over '%'
+int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
+if (i < mmq_y) {
+#else
+int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
+{
+#endif // defined(AMD_MFMA_AVAILABLE)
+if (need_check) {
+i = min(i, i_max);
+}

-
+const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;

-
-
+const int * scales = (const int *) bxi->scales;
+const int ksc = threadIdx.x % 2;

-
-
+const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
+const int m32 = unpack_scales_q45_K(scales, ksc + 2);

-
-
+const uint8_t * sc8 = (const uint8_t *) &sc32;
+const uint8_t * m8 = (const uint8_t *) &m32;

-
+const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);

 #pragma unroll
-
-
+for (int l = 0; l < int(sizeof(int)); ++l) {
+x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+}
 }
 }
-
 #else
-
 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps*
-int i = (i0 + threadIdx.y*
+for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
+int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;

 if (need_check) {
 i = min(i, i_max);
@@ -1557,9 +1968,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 x_dm[i] = bxi->dm;
 }

+constexpr int rows_per_warp = warp_size / 4;
 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps*
-int i = (i0 + threadIdx.y*
+for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;

 if (need_check) {
 i = min(i, i_max);
@@ -1569,17 +1981,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

 const int * scales = (const int *) bxi->scales;

-const int ksc = threadIdx.x % (
+const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
 const int scales8 = unpack_scales_q45_K(scales, ksc);

-x_sc[i*(
+x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
 }
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }

-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+constexpr int nwarps = mmq_get_nwarps_device();
+constexpr int warp_size = ggml_cuda_get_physical_warp_size();

 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
 const int * x_qs = (const int *) x;
@@ -1589,7 +2003,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
 const half2 * y_ds = (const half2 *) y;

 // #pragma unroll
-for (int k01 = 0; k01 <
+for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
 const int k0 = k00 + k01;

 #pragma unroll
@@ -1597,36 +2011,42 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
 const int j = j0 + threadIdx.y;

 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 +=
+for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
 const int i = i0 + threadIdx.x;

-const uint8_t * sc = ((const uint8_t *) &x_sc[i * (
+const uint8_t * sc = ((const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k00/32]) + 2*(k01/16);

-sum[j0/nwarps*mmq_y/
-&x_qs[i*(QR5_K*
+sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q5_K_q8_1_impl_mmq(
+&x_qs[i*(QR5_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
 x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
 }
 }
 }
 }

-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+constexpr int nwarps = mmq_get_nwarps_device();
+constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 int * x_qs = (int *) x_tile;
-float * x_df = (float *) (x_qs +
-int * x_sc = (int *) (x_df +
+float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K);
 #else
 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
 int * x_qs = (int *) x_tile;
 float * x_df = (float *) (x_qs + txs.qs);
 int * x_sc = (int *) (x_df + txs.dm);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+
+constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
+constexpr int nrows = warp_size / threads_per_row;
+const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;

 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-int i = i0 + threadIdx.y;
+for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

 if (need_check) {
 i = min(i, i_max);
@@ -1634,67 +2054,67 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

 const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;

-const int ql = get_int_b2(bxi->ql,
+const int ql = get_int_b2(bxi->ql, txi);
 const int ql0 = (ql >> 0) & 0x0F0F0F0F;
 const int ql1 = (ql >> 4) & 0x0F0F0F0F;

-const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (
-const int qh0 = ((qh >> ((
-const int qh1 = (qh >> ((
+const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4));
+const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030;
+const int qh1 = (qh >> ((txi & 0x08) >> 2)) & 0x30303030;

-const int kq0 = 2*
-const int kq1 = 2*
+const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
+const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
 x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
 #else
-x_qs[i*(2*
-x_qs[i*(2*
-#endif //
+x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }

-const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
-const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
-
 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps
-int i = (i0 + threadIdx.y
+for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
+int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;

 if (need_check) {
 i = min(i, i_max);
 }

-const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride
+const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;

-#
-x_df[i*MMQ_MMA_TILE_X_K_Q6_K
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d;
 #else
-x_df[i*(
-#endif //
+x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }

+constexpr int rows_per_warp = warp_size / 4;
 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps
-int i = (i0 + threadIdx.y
+for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;

 if (need_check) {
 i = min(i, i_max);
 }

-const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (
+const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;

-#
-x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
 #else
-x_sc[i*(
-#endif //
+x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }
 }

-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+constexpr int nwarps = mmq_get_nwarps_device();
+constexpr int warp_size = ggml_cuda_get_physical_warp_size();

 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
 const int * x_qs = (const int *) x;
@@ -1704,7 +2124,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
 const float * y_df = (const float *) y;

 // #pragma unroll
-for (int k01 = 0; k01 <
+for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
 const int k0 = k00 + k01;

 #pragma unroll
@@ -1712,23 +2132,74 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
 const int j = j0 + threadIdx.y;

 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 +=
+for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
 const int i = i0 + threadIdx.x;

-const int8_t * sc = ((const int8_t *) &x_sc[i * (
+const int8_t * sc = ((const int8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/16]);

-sum[j0/nwarps*mmq_y/
-&x_qs[i*(QR6_K*
-x_df[i*(
+sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q6_K_q8_1_impl_mmq(
+&x_qs[i*(QR6_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
+x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
 }
 }
 }
 }

-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
-#
+#if defined(AMD_MFMA_AVAILABLE)
+typedef tile<16, 8, int> tile_A;
+typedef tile<16, 8, int> tile_B;
+typedef tile<16, 16, int> tile_C;
+typedef tile<64, 2, int> tile_load;
+
+constexpr int granularity = mmq_get_granularity_device(mmq_x);
+constexpr int rows_per_warp = granularity;
+constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+const int * x_qs = (const int *) x;
+const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
+const int * y_qs = (const int *) y + 4;
+const float * y_df = (const float *) y;
+
+const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+const int k0 = k00 + k01;
+
+tile_A A[ntx];
+#pragma unroll
+for (int n = 0; n < ntx; ++n) {
+load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
+}
+
+#pragma unroll
+for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+tile_B B[1];
+load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+const int j = j0 + tile_C::get_j(0);
+const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
+
+#pragma unroll
+for (int n = 0; n < ntx; ++n) {
+tile_C C;
+mma(C, A[n], B[0]);
+
+#pragma unroll
+for (int l = 0; l < tile_C::ne; ++l) {
+const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
+sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
+}
+}
+}
+}
+#elif defined(TURING_MMA_AVAILABLE)

 typedef tile<16, 4, int> tile_A;
 typedef tile< 8, 4, int> tile_B;
@@ -1738,11 +2209,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 constexpr int rows_per_warp = 2 * granularity;
 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-y += (threadIdx.y % ntx) * (
+y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

 const int * x_qs = (const int *) x;
-const float * x_df = (const float *) x_qs +
-const int * x_sc = (const int *) x_df +
+const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
 const int * y_qs = (const int *) y + 4;
 const float * y_df = (const float *) y;

@@ -1755,7 +2226,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 #pragma unroll
 for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-for (int k01 = 0; k01 <
+for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
 const int k0 = k00 + k01;

 load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
@@ -1763,7 +2234,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 }

 #pragma unroll
-for (int k01 = 0; k01 <
+for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) {
 const int k0 = k00 + k01;

 #pragma unroll
@@ -1793,7 +2264,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 float tmp[ntx][tile_C::ne] = {{0.0f}};

 #pragma unroll
-for (int k01 = 0; k01 <
+for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
 tile_B B[2];
 float dB[tile_C::ne/2];

@@ -1830,29 +2301,34 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 }
 }
 #else
-
+GGML_UNUSED_VARS(x, y, sum, k00);
 NO_DEVICE_CODE;
-#endif //
+#endif // AMD_MFMA_AVAILABLE
 }

-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+constexpr int nwarps = mmq_get_nwarps_device();
+constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 int * x_qs = (int *) x_tile;
-float * x_df = (float *) (x_qs +
+float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
 int * x_qs = (int *) x_tile;
 float * x_df = (float *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)

-
-
+constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
+constexpr int nrows = warp_size / threads_per_row;
+const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+const int kbx = txi / QI4_NL;
+const int kqsx = txi % QI4_NL;

 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-int i = i0 + threadIdx.y;
+for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

 if (need_check) {
 i = min(i, i_max);
@@ -1861,23 +2337,25 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;

 const int aux_q4 = get_int_b2(bxi->qs, kqsx);
-const int2 v = get_int_from_table_16(aux_q4);
-const int k0 =
-
-
-x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 +
+const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
+const int k0 = kbx * (2 * QI4_NL) + kqsx;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
 #else
-x_qs[i*(2*
-x_qs[i*(2*
-#endif //
+x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
+x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }

-
+constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
+constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
 const int kbxd = threadIdx.x % blocks_per_tile_x_row;

 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-int i = i0 + threadIdx.y *
+for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;

 if (need_check) {
 i = min(i, i_max);
@@ -1885,31 +2363,35 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

 const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;

-#
-x_df[i*MMQ_MMA_TILE_X_K_Q8_0
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
 #else
-x_df[i*(
-#endif //
+x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }
 }

-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+constexpr int nwarps = mmq_get_nwarps_device();
+constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 int * x_qs = (int *) x_tile;
-float * x_df = (float *) (x_qs +
+float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
 int * x_qs = (int *) x_tile;
 float * x_df = (float *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)

-
+constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
+constexpr int nrows = warp_size / threads_per_row;
+const int kqsx = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;

 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-int i = i0 + threadIdx.y*
+for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;

 if (need_check) {
 i = min(i, i_max);
@@ -1932,42 +2414,46 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
 const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
 #else
-x_qs[i*(2*
-x_qs[i*(2*
-#endif //
+x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
+x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }

 const int ls = aux32 >> 28;
 const float d = bxi->d;
-#
-x_df[i*MMQ_MMA_TILE_X_K_Q8_0
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
 #else
-x_df[i*(
-#endif //
+x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }
 }

-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+constexpr int nwarps = mmq_get_nwarps_device();
+constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 int * x_qs = (int *) x_tile;
-float * x_df = (float *) (x_qs +
+float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
 constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
 int * x_qs = (int *) x_tile;
 float * x_df = (float *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)

-
+constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
+constexpr int nrows = warp_size / threads_per_row;
+const int kqsx = threadIdx.x % threads_per_row;

 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-int i = i0 + threadIdx.y*
+for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;

 if (need_check) {
 i = min(i, i_max);
@@ -1986,44 +2472,48 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
 const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
 x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
 #else
-x_qs[i*(2*
-x_qs[i*(2*
-#endif //
+x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }

 const int ls = bxi->scales[kqsx];
 const float d = bxi->d;
-#
-x_df[i*MMQ_MMA_TILE_X_K_Q3_K
-x_df[i*MMQ_MMA_TILE_X_K_Q3_K
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
 #else
-x_df[i*(2*
-x_df[i*(2*
-#endif //
+x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }
 }

-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+constexpr int nwarps = mmq_get_nwarps_device();
+constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 int * x_qs = (int *) x_tile;
-float * x_df = (float *) (x_qs +
+float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
 int * x_qs = (int *) x_tile;
 float * x_df = (float *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)

-
+constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
+constexpr int nrows = warp_size / threads_per_row;
+const int kqsx = threadIdx.x % threads_per_row;

 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-int i = i0 + threadIdx.y*
+for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;

 if (need_check) {
 i = min(i, i_max);
@@ -2049,44 +2539,48 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
 const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
 x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
 #else
-x_qs[i*(2*
-x_qs[i*(2*
-#endif //
+x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }

 const int ls = bxi->scales[kqsx];
 const float d = bxi->d;
-#
-x_df[i*MMQ_MMA_TILE_X_K_Q3_K
-x_df[i*MMQ_MMA_TILE_X_K_Q3_K
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
 #else
-x_df[i*(2*
-x_df[i*(2*
-#endif //
+x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }
 }

-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+constexpr int nwarps = mmq_get_nwarps_device();
+constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 int * x_qs = (int *) x_tile;
-float * x_df = (float *) (x_qs +
+float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
 int * x_qs = (int *) x_tile;
 float * x_df = (float *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)

-
+constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
+constexpr int nrows = warp_size / threads_per_row;
+const int kqsx = threadIdx.x % threads_per_row;

 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-int i = i0 + threadIdx.y*
+for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;

 if (need_check) {
 i = min(i, i_max);
@@ -2107,42 +2601,46 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
 const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
 #else
-x_qs[i*(2*
-x_qs[i*(2*
-#endif //
+x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }

 const int ls = aux32 >> 28;
 const float d = bxi->d;
-#
-x_df[i*MMQ_MMA_TILE_X_K_Q8_0
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
 #else
-x_df[i*(
-#endif //
+x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }
 }

-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+constexpr int nwarps = mmq_get_nwarps_device();
+constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 int * x_qs = (int *) x_tile;
-float * x_df = (float *) (x_qs +
+float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
 int * x_qs = (int *) x_tile;
 float * x_df = (float *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)

-
+constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
+constexpr int nrows = warp_size / threads_per_row;
+const int kqsx = threadIdx.x % threads_per_row;

 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-int i = i0 + threadIdx.y*
+for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;

 if (need_check) {
 i = min(i, i_max);
@@ -2170,42 +2668,46 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
 const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
 #else
-x_qs[i*(2*
-x_qs[i*(2*
-#endif //
+x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
+x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }

 const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
 const float d = bxi->d;
-#
-x_df[i*MMQ_MMA_TILE_X_K_Q8_0
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
 #else
-x_df[i*(
-#endif //
+x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }
 }

-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+constexpr int nwarps = mmq_get_nwarps_device();
+constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 int * x_qs = (int *) x_tile;
-half2 * x_ds = (half2 *) (x_qs +
+half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
 #else
 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
 int * x_qs = (int *) x_tile;
 half2 * x_ds = (half2 *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)

-
+constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
+constexpr int nrows = warp_size / threads_per_row;
+const int kqsx = threadIdx.x % threads_per_row;

 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-int i = i0 + threadIdx.y*
+for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;

 if (need_check) {
 i = min(i, i_max);
@@ -2225,66 +2727,71 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 const int grid0 = (grid >> 0) & 0x0F0F0F0F;
 const int grid1 = (grid >> 4) & 0x0F0F0F0F;

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
 #else
-x_qs[i*(2*
-x_qs[i*(2*
-#endif //
+x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
+x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }

 const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
 const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);

-#
-x_ds[i*MMQ_MMA_TILE_X_K_Q8_1
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
 #else
-x_ds[i*(
-#endif //
+x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }
 }

-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+constexpr int nwarps = mmq_get_nwarps_device();
+constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 int * x_qs = (int *) x_tile;
-float * x_df = (float *) (x_qs +
+float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
 int * x_qs = (int *) x_tile;
 float * x_df = (float *) (x_qs + txs.qs);
-#endif //
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)

-
-
+constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
+constexpr int nrows = warp_size / threads_per_row;
+const int kqsx = threadIdx.x % threads_per_row;

 #pragma unroll
-for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-int i = i0 + threadIdx.y;
+for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
|
2770
|
+
int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
|
2266
2771
|
|
2267
2772
|
if (need_check) {
|
2268
2773
|
i = min(i, i_max);
|
2269
2774
|
}
|
2270
2775
|
|
2271
|
-
const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride
|
2776
|
+
const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
|
2272
2777
|
|
2273
2778
|
const int aux_q4 = get_int_b4(bxi->qs, kqsx);
|
2274
|
-
const int2 v = get_int_from_table_16(aux_q4);
|
2275
|
-
const int k0 = 8 * (
|
2276
|
-
|
2779
|
+
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
|
2780
|
+
const int k0 = 8 * (kqsx / 4) + kqsx % 4;
|
2781
|
+
|
2782
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
2277
2783
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
|
2278
2784
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
|
2279
2785
|
#else
|
2280
|
-
x_qs[i*(2*
|
2281
|
-
x_qs[i*(2*
|
2282
|
-
#endif //
|
2786
|
+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
|
2787
|
+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
|
2788
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
2283
2789
|
}
|
2284
2790
|
|
2791
|
+
constexpr int rows_per_warp = warp_size / 8;
|
2285
2792
|
#pragma unroll
|
2286
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
|
2287
|
-
int i = i0 + threadIdx.y *
|
2793
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
|
2794
|
+
int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4);
|
2288
2795
|
|
2289
2796
|
if (need_check) {
|
2290
2797
|
i = min(i, i_max);
|
@@ -2297,18 +2804,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
2297
2804
|
const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
|
2298
2805
|
| (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
|
2299
2806
|
|
2300
|
-
#
|
2301
|
-
x_df[i*MMQ_MMA_TILE_X_K_Q8_0
|
2807
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
2808
|
+
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
|
2302
2809
|
#else
|
2303
|
-
x_df[i*(
|
2304
|
-
#endif //
|
2810
|
+
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
|
2811
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
2305
2812
|
}
|
2306
2813
|
}
|
2307
2814
|
|
2308
|
-
template<int mmq_x, int mmq_y,
|
2815
|
+
template<int mmq_x, int mmq_y, bool need_check>
|
2309
2816
|
static __device__ __forceinline__ void mmq_write_back_dp4a(
|
2310
2817
|
const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst,
|
2311
2818
|
const int stride, const int i_max, const int j_max) {
|
2819
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
2820
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
2821
|
+
|
2312
2822
|
#pragma unroll
|
2313
2823
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
2314
2824
|
const int j = j0 + threadIdx.y;
|
@@ -2318,32 +2828,42 @@ static __device__ __forceinline__ void mmq_write_back_dp4a(
|
|
2318
2828
|
}
|
2319
2829
|
|
2320
2830
|
#pragma unroll
|
2321
|
-
for (int i0 = 0; i0 < mmq_y; i0 +=
|
2831
|
+
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
2322
2832
|
const int i = i0 + threadIdx.x;
|
2323
2833
|
|
2324
2834
|
if (need_check && i > i_max) {
|
2325
2835
|
continue;
|
2326
2836
|
}
|
2327
2837
|
|
2328
|
-
dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/
|
2838
|
+
dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
|
2329
2839
|
}
|
2330
2840
|
}
|
2331
2841
|
}
|
2332
2842
|
|
2333
|
-
template<
|
2843
|
+
template<ggml_type type, int mmq_x, int mmq_y, bool need_check>
|
2334
2844
|
static __device__ __forceinline__ void mmq_write_back_mma(
|
2335
2845
|
const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst,
|
2336
2846
|
const int stride, const int i_max, const int j_max) {
|
2337
|
-
typedef tile<16, 8, int> tile_C;
|
2338
2847
|
|
2339
2848
|
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
2849
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
2850
|
+
|
2851
|
+
#if defined(AMD_MFMA_AVAILABLE)
|
2852
|
+
constexpr int tileC_IJ = mmq_get_granularity_device(0);
|
2853
|
+
typedef tile<tileC_IJ, tileC_IJ, int> tile_C;
|
2854
|
+
constexpr int rows_per_warp = granularity;
|
2855
|
+
#else
|
2856
|
+
typedef tile<16, 8, int> tile_C;
|
2340
2857
|
constexpr int rows_per_warp = 2 * granularity;
|
2858
|
+
#endif // defined(AMD_MFMA_AVAILABLE)
|
2341
2859
|
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
2342
2860
|
|
2343
2861
|
const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
|
2344
|
-
#
|
2862
|
+
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
2345
2863
|
static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
|
2346
|
-
#
|
2864
|
+
#else
|
2865
|
+
GGML_UNUSED(nwarps);
|
2866
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
2347
2867
|
|
2348
2868
|
#pragma unroll
|
2349
2869
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
@@ -2371,179 +2891,189 @@ static __device__ __forceinline__ void mmq_write_back_mma(
|
|
2371
2891
|
|
2372
2892
|
// -------------------------------------------------------------------------------------------------------------------------------------
|
2373
2893
|
|
2374
|
-
template <int mmq_x, int mmq_y,
|
2894
|
+
template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
|
2375
2895
|
struct mmq_type_traits;
|
2376
2896
|
|
2377
|
-
template <int mmq_x, int mmq_y,
|
2378
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
2897
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
2898
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
|
2379
2899
|
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
|
2380
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y,
|
2381
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y,
|
2382
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y
|
2900
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, need_check>;
|
2901
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_DS4>;
|
2902
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y>;
|
2383
2903
|
};
|
2384
2904
|
|
2385
|
-
template <int mmq_x, int mmq_y,
|
2386
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
2905
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
2906
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_1> {
|
2387
2907
|
static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
|
2388
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y,
|
2389
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y
|
2390
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y
|
2908
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, need_check>;
|
2909
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
|
2910
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y>;
|
2391
2911
|
};
|
2392
2912
|
|
2393
|
-
template <int mmq_x, int mmq_y,
|
2394
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
2913
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
2914
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_0> {
|
2395
2915
|
static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
|
2396
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y,
|
2397
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y,
|
2398
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y
|
2916
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, need_check>;
|
2917
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
2918
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
2399
2919
|
};
|
2400
2920
|
|
2401
|
-
template <int mmq_x, int mmq_y,
|
2402
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
2921
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
2922
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_1> {
|
2403
2923
|
static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
|
2404
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y,
|
2405
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y
|
2406
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y
|
2924
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, need_check>;
|
2925
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
|
2926
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
|
2407
2927
|
};
|
2408
2928
|
|
2409
|
-
template <int mmq_x, int mmq_y,
|
2410
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
2929
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
2930
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
|
2411
2931
|
static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
|
2412
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y,
|
2413
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y,
|
2414
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y
|
2932
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, need_check>;
|
2933
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
2934
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
2935
|
+
};
|
2936
|
+
|
2937
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
2938
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
|
2939
|
+
static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
|
2940
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
|
2941
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
2942
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
2415
2943
|
};
|
2416
2944
|
|
2417
|
-
template <int mmq_x, int mmq_y,
|
2418
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
2945
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
2946
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
|
2419
2947
|
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
|
2420
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y,
|
2421
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y
|
2422
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y
|
2948
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, need_check>;
|
2949
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y>;
|
2950
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y>;
|
2423
2951
|
};
|
2424
2952
|
|
2425
|
-
template <int mmq_x, int mmq_y,
|
2426
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
2953
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
2954
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q3_K> {
|
2427
2955
|
static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
|
2428
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y,
|
2429
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y
|
2430
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y
|
2956
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, need_check>;
|
2957
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
|
2958
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y>;
|
2431
2959
|
};
|
2432
2960
|
|
2433
|
-
template <int mmq_x, int mmq_y,
|
2434
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
2961
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
2962
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_K> {
|
2435
2963
|
static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
|
2436
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y,
|
2437
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y
|
2438
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y
|
2964
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, need_check>;
|
2965
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
|
2966
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y>;
|
2439
2967
|
};
|
2440
2968
|
|
2441
|
-
template <int mmq_x, int mmq_y,
|
2442
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
2969
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
2970
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_K> {
|
2443
2971
|
static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
|
2444
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y,
|
2445
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y
|
2446
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y
|
2972
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, need_check>;
|
2973
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
|
2974
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y>;
|
2447
2975
|
};
|
2448
2976
|
|
2449
|
-
template <int mmq_x, int mmq_y,
|
2450
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
2977
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
2978
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q6_K> {
|
2451
2979
|
static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
|
2452
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y,
|
2453
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y
|
2454
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y
|
2980
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, need_check>;
|
2981
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y>;
|
2982
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y>;
|
2455
2983
|
};
|
2456
2984
|
|
2457
|
-
template <int mmq_x, int mmq_y,
|
2458
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
2985
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
2986
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XXS> {
|
2459
2987
|
static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
|
2460
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y,
|
2461
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y,
|
2462
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y
|
2988
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, need_check>;
|
2989
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
2990
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
2463
2991
|
};
|
2464
2992
|
|
2465
|
-
template <int mmq_x, int mmq_y,
|
2466
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
2993
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
2994
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XS> {
|
2467
2995
|
static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
|
2468
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y,
|
2469
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y
|
2470
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y
|
2996
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, need_check>;
|
2997
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
|
2998
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
|
2471
2999
|
};
|
2472
3000
|
|
2473
|
-
template <int mmq_x, int mmq_y,
|
2474
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
3001
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
3002
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_S> {
|
2475
3003
|
static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ;
|
2476
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y,
|
2477
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y
|
2478
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y
|
3004
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, need_check>;
|
3005
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
|
3006
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
|
2479
3007
|
};
|
2480
3008
|
|
2481
|
-
template <int mmq_x, int mmq_y,
|
2482
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
3009
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
3010
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_XXS> {
|
2483
3011
|
static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ;
|
2484
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y,
|
2485
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y,
|
2486
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y
|
3012
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, need_check>;
|
3013
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
3014
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
2487
3015
|
};
|
2488
3016
|
|
2489
|
-
template <int mmq_x, int mmq_y,
|
2490
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
3017
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
3018
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_S> {
|
2491
3019
|
static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ;
|
2492
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y,
|
2493
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y,
|
2494
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y
|
3020
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, need_check>;
|
3021
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
3022
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
2495
3023
|
};
|
2496
3024
|
|
2497
|
-
template <int mmq_x, int mmq_y,
|
2498
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
3025
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
3026
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_S> {
|
2499
3027
|
static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ;
|
2500
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y,
|
2501
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y
|
2502
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y
|
3028
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, need_check>;
|
3029
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
|
3030
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
|
2503
3031
|
};
|
2504
3032
|
|
2505
|
-
template <int mmq_x, int mmq_y,
|
2506
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
3033
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
3034
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_NL> {
|
2507
3035
|
static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
|
2508
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y,
|
2509
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y,
|
2510
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y
|
3036
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, need_check>;
|
3037
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
3038
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
2511
3039
|
};
|
2512
3040
|
|
2513
|
-
template <int mmq_x, int mmq_y,
|
2514
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
3041
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
3042
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_XS> {
|
2515
3043
|
static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
|
2516
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y,
|
2517
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y,
|
2518
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y
|
3044
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, need_check>;
|
3045
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
3046
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
2519
3047
|
};
|
2520
3048
|
|
2521
|
-
template <ggml_type type, int mmq_x,
|
3049
|
+
template <ggml_type type, int mmq_x, bool need_check, bool fixup>
|
2522
3050
|
static __device__ __forceinline__ void mul_mat_q_process_tile(
|
2523
3051
|
const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
|
2524
3052
|
const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
|
2525
3053
|
const int stride_row_x, const int ncols_y, const int stride_col_dst,
|
2526
3054
|
const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {
|
2527
3055
|
|
3056
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
3057
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
2528
3058
|
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
2529
3059
|
constexpr int mmq_y = get_mmq_y_device();
|
2530
|
-
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y,
|
3060
|
+
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, need_check, type>::load_tiles;
|
2531
3061
|
|
2532
3062
|
extern __shared__ int data_mul_mat_q[];
|
2533
3063
|
int * tile_y = data_mul_mat_q + mmq_x;
|
2534
|
-
int * tile_x = tile_y + GGML_PAD(mmq_x*
|
3064
|
+
int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);
|
2535
3065
|
|
2536
|
-
#
|
2537
|
-
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y,
|
2538
|
-
constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y,
|
3066
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
3067
|
+
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
|
3068
|
+
constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
|
2539
3069
|
#else
|
2540
|
-
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y,
|
2541
|
-
constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y,
|
2542
|
-
#endif //
|
3070
|
+
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
|
3071
|
+
constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
|
3072
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
2543
3073
|
|
2544
3074
|
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
|
2545
3075
|
|
2546
|
-
float sum[mmq_x*mmq_y / (nwarps*
|
3076
|
+
float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
|
2547
3077
|
|
2548
3078
|
for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
|
2549
3079
|
load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
|
@@ -2551,8 +3081,8 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
|
|
2551
3081
|
{
|
2552
3082
|
const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
|
2553
3083
|
#pragma unroll
|
2554
|
-
for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*
|
2555
|
-
int l = l0 + threadIdx.y*
|
3084
|
+
for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
|
3085
|
+
int l = l0 + threadIdx.y*warp_size + threadIdx.x;
|
2556
3086
|
|
2557
3087
|
tile_y[l] = by0[l];
|
2558
3088
|
}
|
@@ -2567,8 +3097,8 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
|
|
2567
3097
|
{
|
2568
3098
|
const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
|
2569
3099
|
#pragma unroll
|
2570
|
-
for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*
|
2571
|
-
int l = l0 + threadIdx.y*
|
3100
|
+
for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
|
3101
|
+
int l = l0 + threadIdx.y*warp_size + threadIdx.x;
|
2572
3102
|
|
2573
3103
|
tile_y[l] = by0[l];
|
2574
3104
|
}
|
@@ -2576,7 +3106,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
|
|
2576
3106
|
|
2577
3107
|
__syncthreads();
|
2578
3108
|
|
2579
|
-
vec_dot(tile_x, tile_y, sum,
|
3109
|
+
vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K);
|
2580
3110
|
|
2581
3111
|
__syncthreads();
|
2582
3112
|
}
|
@@ -2591,24 +3121,25 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
|
|
2591
3121
|
|
2592
3122
|
// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
|
2593
3123
|
|
2594
|
-
template <ggml_type type, int mmq_x,
|
2595
|
-
#if defined(GGML_USE_HIP)
|
3124
|
+
template <ggml_type type, int mmq_x, bool need_check>
|
3125
|
+
#if defined(GGML_USE_HIP)
|
2596
3126
|
#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
2597
|
-
__launch_bounds__(
|
3127
|
+
__launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
|
2598
3128
|
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
2599
3129
|
#else
|
2600
3130
|
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
2601
|
-
__launch_bounds__(
|
3131
|
+
__launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1)
|
2602
3132
|
#else
|
2603
|
-
__launch_bounds__(
|
3133
|
+
__launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
|
2604
3134
|
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
2605
|
-
#endif // defined(GGML_USE_HIP)
|
3135
|
+
#endif // defined(GGML_USE_HIP)
|
2606
3136
|
static __global__ void mul_mat_q(
|
2607
3137
|
const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
|
2608
3138
|
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
|
2609
3139
|
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
|
2610
3140
|
const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
2611
|
-
const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst
|
3141
|
+
const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
3142
|
+
const int ncols_max) {
|
2612
3143
|
|
2613
3144
|
// Skip unused template specializations for faster compilation:
|
2614
3145
|
if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
|
@@ -2616,10 +3147,13 @@ static __global__ void mul_mat_q(
|
|
2616
3147
|
return;
|
2617
3148
|
}
|
2618
3149
|
|
3150
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
3151
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
3152
|
+
|
2619
3153
|
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
2620
3154
|
constexpr int mmq_y = get_mmq_y_device();
|
2621
3155
|
|
2622
|
-
const int ntx = (
|
3156
|
+
const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
|
2623
3157
|
const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
|
2624
3158
|
|
2625
3159
|
// Initialize the ids for writing back data with just the index.
|
@@ -2627,10 +3161,10 @@ static __global__ void mul_mat_q(
|
|
2627
3161
|
// For MoE the correct indices are loaded from ids_dst.
|
2628
3162
|
extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory.
|
2629
3163
|
#pragma unroll
|
2630
|
-
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*
|
2631
|
-
const int j = j0 + threadIdx.y*
|
3164
|
+
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
|
3165
|
+
const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
|
2632
3166
|
|
2633
|
-
if (j0 + nwarps*
|
3167
|
+
if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
|
2634
3168
|
break;
|
2635
3169
|
}
|
2636
3170
|
|
@@ -2638,8 +3172,8 @@ static __global__ void mul_mat_q(
|
|
2638
3172
|
}
|
2639
3173
|
__syncthreads();
|
2640
3174
|
|
2641
|
-
// On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
|
2642
|
-
#if (defined(GGML_USE_HIP) && defined(
|
3175
|
+
// On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
|
3176
|
+
#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
|
2643
3177
|
{
|
2644
3178
|
const int wt = blockIdx.z / nchannels_y;
|
2645
3179
|
const int zt = blockIdx.z - wt*nchannels_y;
|
@@ -2667,10 +3201,10 @@ static __global__ void mul_mat_q(
|
|
2667
3201
|
|
2668
3202
|
// __syncthreads(); // There is no previous tile that could cause a race condition.
|
2669
3203
|
#pragma unroll
|
2670
|
-
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*
|
2671
|
-
const int j = j0 + threadIdx.y*
|
3204
|
+
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
|
3205
|
+
const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
|
2672
3206
|
|
2673
|
-
if (j0 + nwarps*
|
3207
|
+
if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
|
2674
3208
|
break;
|
2675
3209
|
}
|
2676
3210
|
|
@@ -2688,12 +3222,12 @@ static __global__ void mul_mat_q(
|
|
2688
3222
|
const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
|
2689
3223
|
|
2690
3224
|
constexpr bool fixup = false;
|
2691
|
-
mul_mat_q_process_tile<type, mmq_x,
|
3225
|
+
mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
|
2692
3226
|
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
|
2693
3227
|
tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
|
2694
3228
|
return;
|
2695
3229
|
}
|
2696
|
-
#endif // (defined(GGML_USE_HIP) && defined(
|
3230
|
+
#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
|
2697
3231
|
|
2698
3232
|
const int64_t blocks_per_ne00 = ncols_x / qk;
|
2699
3233
|
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
|
@@ -2745,10 +3279,10 @@ static __global__ void mul_mat_q(
|
|
2745
3279
|
|
2746
3280
|
__syncthreads();
|
2747
3281
|
#pragma unroll
|
2748
|
-
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*
|
2749
|
-
const int j = j0 + threadIdx.y*
|
3282
|
+
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
|
3283
|
+
const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
|
2750
3284
|
|
2751
|
-
if (j0 + nwarps*
|
3285
|
+
if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
|
2752
3286
|
break;
|
2753
3287
|
}
|
2754
3288
|
|
@@ -2766,7 +3300,7 @@ static __global__ void mul_mat_q(
|
|
2766
3300
|
const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
|
2767
3301
|
|
2768
3302
|
constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
2769
|
-
mul_mat_q_process_tile<type, mmq_x,
|
3303
|
+
mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
|
2770
3304
|
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
|
2771
3305
|
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
|
2772
3306
|
|
@@ -2812,10 +3346,10 @@ static __global__ void mul_mat_q(
|
|
2812
3346
|
// The memory layout for the fixup buffer is always contiguous, therefore reset ids:
|
2813
3347
|
__syncthreads();
|
2814
3348
|
#pragma unroll
|
2815
|
-
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*
|
2816
|
-
const int j = j0 + threadIdx.y*
|
3349
|
+
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
|
3350
|
+
const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
|
2817
3351
|
|
2818
|
-
if (j0 + nwarps*
|
3352
|
+
if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
|
2819
3353
|
break;
|
2820
3354
|
}
|
2821
3355
|
|
@@ -2833,25 +3367,29 @@ static __global__ void mul_mat_q(
|
|
2833
3367
|
const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
|
2834
3368
|
|
2835
3369
|
constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
2836
|
-
mul_mat_q_process_tile<type, mmq_x,
|
3370
|
+
mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
|
2837
3371
|
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
|
2838
3372
|
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
|
2839
3373
|
}
|
2840
3374
|
|
2841
3375
|
|
2842
|
-
template <ggml_type type, int mmq_x,
|
3376
|
+
template <ggml_type type, int mmq_x, bool need_check>
|
2843
3377
|
static __global__ void mul_mat_q_stream_k_fixup(
|
2844
3378
|
const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
|
2845
3379
|
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
|
2846
|
-
const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst
|
3380
|
+
const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
|
3381
|
+
const int ncols_max) {
|
2847
3382
|
constexpr int mmq_y = get_mmq_y_device();
|
2848
3383
|
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
2849
3384
|
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
|
2850
3385
|
const int64_t blocks_per_ne00 = ncols_x / qk;
|
2851
3386
|
|
2852
|
-
|
3387
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
3388
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
2853
3389
|
|
2854
|
-
|
3390
|
+
float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
|
3391
|
+
|
3392
|
+
const int ntx = (ncols_max + mmq_x - 1) / mmq_x;
|
2855
3393
|
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
|
2856
3394
|
|
2857
3395
|
const int bidx0 = blockIdx.x;
|
@@ -2893,10 +3431,10 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
|
2893
3431
|
const int j = j0 + threadIdx.y;
|
2894
3432
|
|
2895
3433
|
#pragma unroll
|
2896
|
-
for (int i0 = 0; i0 < mmq_y; i0 +=
|
3434
|
+
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
2897
3435
|
const int i = i0 + threadIdx.x;
|
2898
3436
|
|
2899
|
-
sum[(j0/nwarps) * (mmq_y/
|
3437
|
+
sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
|
2900
3438
|
}
|
2901
3439
|
}
|
2902
3440
|
|
@@ -2937,14 +3475,14 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
|
2937
3475
|
}
|
2938
3476
|
|
2939
3477
|
#pragma unroll
|
2940
|
-
for (int i0 = 0; i0 < mmq_y; i0 +=
|
3478
|
+
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
2941
3479
|
const int i = i0 + threadIdx.x;
|
2942
3480
|
|
2943
3481
|
if (need_check && i > i_max) {
|
2944
3482
|
continue;
|
2945
3483
|
}
|
2946
3484
|
|
2947
|
-
dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/
|
3485
|
+
dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
|
2948
3486
|
}
|
2949
3487
|
}
|
2950
3488
|
return;
|
@@ -2955,7 +3493,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
|
2955
3493
|
const int col_high = expert_bounds[zt + 1];
|
2956
3494
|
const int col_diff = col_high - col_low;
|
2957
3495
|
|
2958
|
-
for (int j = threadIdx.y*
|
3496
|
+
for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {
|
2959
3497
|
ids_dst_shared[j] = ids_dst[col_low + j];
|
2960
3498
|
}
|
2961
3499
|
__syncthreads();
|
@@ -2975,14 +3513,14 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
|
2975
3513
|
}
|
2976
3514
|
|
2977
3515
|
#pragma unroll
|
2978
|
-
for (int i0 = 0; i0 < mmq_y; i0 +=
|
3516
|
+
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
2979
3517
|
const int i = i0 + threadIdx.x;
|
2980
3518
|
|
2981
3519
|
if (need_check && i > i_max) {
|
2982
3520
|
continue;
|
2983
3521
|
}
|
2984
3522
|
|
2985
|
-
dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/
|
3523
|
+
dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
|
2986
3524
|
}
|
2987
3525
|
}
|
2988
3526
|
}
|
@@ -2992,17 +3530,17 @@ struct mmq_args {
|
|
2992
3530
|
int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
|
2993
3531
|
int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
|
2994
3532
|
int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
|
2995
|
-
bool use_stream_k;
|
3533
|
+
bool use_stream_k; int64_t ncols_max;
|
2996
3534
|
};
|
2997
3535
|
|
2998
3536
|
template<ggml_type type>
|
2999
|
-
static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc) {
|
3537
|
+
static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) {
|
3000
3538
|
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
|
3001
3539
|
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
|
3002
3540
|
const size_t nbs_ids = mmq_x*sizeof(int);
|
3003
|
-
const size_t nbs_x =
|
3541
|
+
const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
|
3004
3542
|
const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
|
3005
|
-
return nbs_ids + nbs_x + GGML_PAD(nbs_y,
|
3543
|
+
return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
|
3006
3544
|
}
|
3007
3545
|
|
3008
3546
|
template <ggml_type type, int mmq_x>
|
@@ -3010,23 +3548,19 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
|
3010
3548
|
const int id = ggml_cuda_get_device();
|
3011
3549
|
const int cc = ggml_cuda_info().devices[id].cc;
|
3012
3550
|
const int nsm = ggml_cuda_info().devices[id].nsm;
|
3551
|
+
const int warp_size = ggml_cuda_info().devices[id].warp_size;
|
3552
|
+
const int nwarps = mmq_get_nwarps_host(cc, warp_size);
|
3013
3553
|
const int mmq_y = get_mmq_y_host(cc);
|
3014
3554
|
|
3015
|
-
const dim3 block_dims(
|
3555
|
+
const dim3 block_dims(warp_size, nwarps, 1);
|
3016
3556
|
|
3017
|
-
const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);
|
3557
|
+
const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps);
|
3018
3558
|
|
3019
|
-
|
3020
|
-
|
3021
|
-
if (!shared_memory_limit_raised[id]) {
|
3022
|
-
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
|
3023
|
-
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
|
3024
|
-
shared_memory_limit_raised[id] = true;
|
3025
|
-
}
|
3026
|
-
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
3559
|
+
CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, false>), nbytes_shared);
|
3560
|
+
CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true>), nbytes_shared);
|
3027
3561
|
|
3028
3562
|
const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
|
3029
|
-
const int ntx = (args.
|
3563
|
+
const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x;
|
3030
3564
|
const int ntzw = args.nchannels_y * args.nsamples_y;
|
3031
3565
|
const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
|
3032
3566
|
|
@@ -3038,18 +3572,20 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
|
3038
3572
|
if (!args.use_stream_k) {
|
3039
3573
|
if (args.nrows_x % mmq_y == 0) {
|
3040
3574
|
constexpr bool need_check = false;
|
3041
|
-
mul_mat_q<type, mmq_x,
|
3575
|
+
mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
|
3042
3576
|
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
|
3043
3577
|
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
3044
3578
|
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
3045
|
-
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst
|
3579
|
+
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
3580
|
+
args.ncols_max);
|
3046
3581
|
} else {
|
3047
3582
|
constexpr bool need_check = true;
|
3048
|
-
mul_mat_q<type, mmq_x,
|
3583
|
+
mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
|
3049
3584
|
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
|
3050
3585
|
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
3051
3586
|
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
3052
|
-
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst
|
3587
|
+
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
3588
|
+
args.ncols_max);
|
3053
3589
|
}
|
3054
3590
|
return;
|
3055
3591
|
}
|
@@ -3065,44 +3601,48 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
|
3065
3601
|
|
3066
3602
|
if (args.nrows_x % mmq_y == 0) {
|
3067
3603
|
constexpr bool need_check = false;
|
3068
|
-
|
3069
|
-
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
|
3604
|
+
mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
|
3070
3605
|
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
|
3071
3606
|
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
3072
3607
|
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
3073
|
-
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst
|
3608
|
+
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
3609
|
+
args.ncols_max);
|
3074
3610
|
|
3075
3611
|
if (!fixup_needed) {
|
3076
3612
|
return;
|
3077
3613
|
}
|
3078
3614
|
|
3079
|
-
mul_mat_q_stream_k_fixup<type, mmq_x,
|
3615
|
+
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
|
3080
3616
|
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
|
3081
|
-
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst
|
3617
|
+
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
|
3618
|
+
args.ncols_max);
|
3082
3619
|
} else {
|
3083
3620
|
constexpr bool need_check = true;
|
3084
|
-
|
3085
|
-
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
|
3621
|
+
mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
|
3086
3622
|
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
|
3087
3623
|
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
3088
3624
|
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
3089
|
-
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst
|
3625
|
+
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
3626
|
+
args.ncols_max);
|
3090
3627
|
|
3091
3628
|
if (!fixup_needed) {
|
3092
3629
|
return;
|
3093
3630
|
}
|
3094
3631
|
|
3095
|
-
mul_mat_q_stream_k_fixup<type, mmq_x,
|
3632
|
+
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
|
3096
3633
|
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
|
3097
|
-
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst
|
3634
|
+
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
|
3635
|
+
args.ncols_max);
|
3098
3636
|
}
|
3099
3637
|
}
|
3100
3638
|
|
3101
3639
|
template <ggml_type type>
|
3102
3640
|
void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
|
3103
|
-
const int id
|
3104
|
-
const int cc
|
3105
|
-
const size_t smpbo
|
3641
|
+
const int id = ggml_cuda_get_device();
|
3642
|
+
const int cc = ggml_cuda_info().devices[id].cc;
|
3643
|
+
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
3644
|
+
const int warp_size = ggml_cuda_info().devices[id].warp_size;
|
3645
|
+
const int nwarps = mmq_get_nwarps_host(cc, warp_size);
|
3106
3646
|
|
3107
3647
|
const int mmq_x_max = get_mmq_x_max_host(cc);
|
3108
3648
|
const int mmq_y = get_mmq_y_host(cc);
|
@@ -3113,11 +3653,11 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
|
|
3113
3653
|
for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) {
|
3114
3654
|
const int granularity = mmq_get_granularity_host(mmq_x, cc);
|
3115
3655
|
|
3116
|
-
if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc) > smpbo) {
|
3656
|
+
if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) {
|
3117
3657
|
continue;
|
3118
3658
|
}
|
3119
3659
|
|
3120
|
-
const int ntiles_x = (args.
|
3660
|
+
const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x;
|
3121
3661
|
|
3122
3662
|
if (ntiles_x < ntiles_x_best) {
|
3123
3663
|
mmq_x_best = mmq_x;
|
@@ -3189,6 +3729,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
|
|
3189
3729
|
extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
|
3190
3730
|
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
|
3191
3731
|
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
|
3732
|
+
extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
|
3192
3733
|
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
|
3193
3734
|
extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
|
3194
3735
|
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
|