whispercpp 1.3.4 → 1.3.5
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +47 -23
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/cli/cli.cpp +121 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +10 -11
- data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
- data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
- data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
- data/ext/sources/examples/talk-llama/llama-context.h +57 -9
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
- data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
- data/ext/sources/examples/talk-llama/llama-model.h +44 -3
- data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
- data/ext/sources/examples/talk-llama/llama.cpp +729 -2
- data/ext/sources/examples/talk-llama/llama.h +152 -14
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +82 -54
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +4 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +190 -12
- data/ext/sources/ggml/src/CMakeLists.txt +82 -11
- data/ext/sources/ggml/src/ggml-alloc.c +124 -41
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
- data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
- data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
- data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
- data/ext/sources/ggml/src/ggml-impl.h +67 -6
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +425 -33
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +101 -35
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +119 -2
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +7 -0
- data/whispercpp.gemspec +1 -1
- metadata +287 -34
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -11,6 +11,7 @@ using namespace ggml_cuda_mma;
|
|
|
11
11
|
|
|
12
12
|
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
|
|
13
13
|
#define MMQ_ITER_K 256
|
|
14
|
+
#define MMQ_ITER_K_MXFP4_FP4 512
|
|
14
15
|
#define MMQ_NWARPS 8
|
|
15
16
|
|
|
16
17
|
typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
|
|
@@ -44,8 +45,15 @@ struct block_q8_1_mmq {
|
|
|
44
45
|
};
|
|
45
46
|
int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
|
|
46
47
|
};
|
|
48
|
+
|
|
49
|
+
struct block_fp4_mmq {
|
|
50
|
+
uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
|
|
51
|
+
int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
|
|
52
|
+
};
|
|
53
|
+
|
|
47
54
|
static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
|
|
48
55
|
static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
|
|
56
|
+
static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected block_fp4_mmq size");
|
|
49
57
|
|
|
50
58
|
static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
|
|
51
59
|
switch (type_x) {
|
|
@@ -92,7 +100,7 @@ struct tile_x_sizes {
|
|
|
92
100
|
};
|
|
93
101
|
|
|
94
102
|
static int get_mmq_x_max_host(const int cc) {
|
|
95
|
-
return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 :
|
|
103
|
+
return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 :
|
|
96
104
|
GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
|
|
97
105
|
#ifdef GGML_CUDA_FORCE_MMQ
|
|
98
106
|
128 : 64;
|
|
@@ -102,7 +110,7 @@ static int get_mmq_x_max_host(const int cc) {
|
|
|
102
110
|
}
|
|
103
111
|
|
|
104
112
|
static constexpr __device__ int get_mmq_x_max_device() {
|
|
105
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
113
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
106
114
|
return 128;
|
|
107
115
|
#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
108
116
|
|
|
@@ -121,7 +129,7 @@ static constexpr __device__ int get_mmq_x_max_device() {
|
|
|
121
129
|
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
|
122
130
|
|
|
123
131
|
#endif // defined(GGML_USE_HIP)
|
|
124
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
132
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
125
133
|
}
|
|
126
134
|
|
|
127
135
|
static int get_mmq_y_host(const int cc) {
|
|
@@ -129,6 +137,14 @@ static int get_mmq_y_host(const int cc) {
|
|
|
129
137
|
((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
|
|
130
138
|
}
|
|
131
139
|
|
|
140
|
+
static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
|
|
141
|
+
#if defined(BLACKWELL_MMA_AVAILABLE)
|
|
142
|
+
return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
|
|
143
|
+
#else
|
|
144
|
+
return MMQ_ITER_K;
|
|
145
|
+
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
|
146
|
+
}
|
|
147
|
+
|
|
132
148
|
static constexpr __device__ int get_mmq_y_device() {
|
|
133
149
|
#if defined(GGML_USE_HIP)
|
|
134
150
|
#if defined(RDNA1)
|
|
@@ -191,6 +207,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
|
|
|
191
207
|
}
|
|
192
208
|
|
|
193
209
|
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
|
|
210
|
+
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4)
|
|
194
211
|
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
|
|
195
212
|
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
|
|
196
213
|
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
|
|
@@ -201,6 +218,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
|
|
|
201
218
|
static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
|
|
202
219
|
static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
|
|
203
220
|
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
|
|
221
|
+
static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
|
|
222
|
+
static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
|
|
204
223
|
|
|
205
224
|
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
|
206
225
|
switch (type) {
|
|
@@ -209,6 +228,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
|
|
209
228
|
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
|
210
229
|
case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
|
|
211
230
|
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
|
231
|
+
// tile sizes are the same for Q8_1 and FP4 for blackwell
|
|
212
232
|
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
|
|
213
233
|
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
|
|
214
234
|
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
|
|
@@ -228,10 +248,11 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
|
|
228
248
|
}
|
|
229
249
|
|
|
230
250
|
// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
|
|
231
|
-
#define MMQ_TILE_Y_K
|
|
251
|
+
#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_1)
|
|
252
|
+
#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K
|
|
232
253
|
|
|
233
254
|
static int mmq_get_granularity_host(const int mmq_x, const int cc) {
|
|
234
|
-
if (amd_mfma_available(cc)) {
|
|
255
|
+
if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
|
|
235
256
|
return mmq_x >= 128 ? 32 : 16;
|
|
236
257
|
} else if (turing_mma_available(cc) && mmq_x >= 48) {
|
|
237
258
|
return 16;
|
|
@@ -240,7 +261,7 @@ static int mmq_get_granularity_host(const int mmq_x, const int cc) {
|
|
|
240
261
|
}
|
|
241
262
|
}
|
|
242
263
|
|
|
243
|
-
#if defined(AMD_MFMA_AVAILABLE)
|
|
264
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
244
265
|
static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
|
|
245
266
|
return mmq_x >= 128 ? 32 : 16;
|
|
246
267
|
}
|
|
@@ -265,7 +286,7 @@ static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
|
|
|
265
286
|
#endif // (GGML_USE_HIP)
|
|
266
287
|
|
|
267
288
|
static constexpr __device__ int mmq_get_nwarps_device() {
|
|
268
|
-
#if defined(AMD_MFMA_AVAILABLE)
|
|
289
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
269
290
|
return 8;
|
|
270
291
|
#else
|
|
271
292
|
return 256/ggml_cuda_get_physical_warp_size();
|
|
@@ -279,14 +300,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
279
300
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
280
301
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
281
302
|
|
|
282
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
303
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
283
304
|
int * x_qs = (int *) x_tile;
|
|
284
305
|
float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
|
|
285
306
|
#else
|
|
286
307
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
|
|
287
308
|
int * x_qs = (int *) x_tile;
|
|
288
309
|
float * x_df = (float *) (x_qs + txs.qs);
|
|
289
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
310
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
290
311
|
|
|
291
312
|
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
|
|
292
313
|
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -305,7 +326,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
305
326
|
const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
|
|
306
327
|
const int qs0 = get_int_b2(bxi->qs, kqsx);
|
|
307
328
|
|
|
308
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
329
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
309
330
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
|
|
310
331
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
|
|
311
332
|
#else
|
|
@@ -327,11 +348,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
327
348
|
|
|
328
349
|
const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
|
|
329
350
|
|
|
330
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
351
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
331
352
|
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
|
|
332
353
|
#else
|
|
333
354
|
x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
|
|
334
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
355
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
335
356
|
}
|
|
336
357
|
}
|
|
337
358
|
|
|
@@ -382,14 +403,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
382
403
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
383
404
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
384
405
|
|
|
385
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
406
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
386
407
|
int * x_qs = (int *) x_tile;
|
|
387
408
|
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
|
|
388
409
|
#else
|
|
389
410
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
|
|
390
411
|
int * x_qs = (int *) x_tile;
|
|
391
412
|
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
|
392
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
413
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
393
414
|
|
|
394
415
|
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
|
|
395
416
|
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -408,12 +429,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
408
429
|
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
|
|
409
430
|
const int qs0 = get_int_b4(bxi->qs, kqsx);
|
|
410
431
|
|
|
411
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
432
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
412
433
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
|
|
413
434
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
|
|
414
435
|
#else
|
|
415
436
|
x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
|
|
416
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
437
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
417
438
|
}
|
|
418
439
|
|
|
419
440
|
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
|
|
@@ -430,11 +451,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
430
451
|
|
|
431
452
|
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
|
|
432
453
|
|
|
433
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
454
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
434
455
|
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
|
|
435
456
|
#else
|
|
436
457
|
x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
|
|
437
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
458
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
438
459
|
}
|
|
439
460
|
}
|
|
440
461
|
|
|
@@ -485,14 +506,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
485
506
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
486
507
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
487
508
|
|
|
488
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
509
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
489
510
|
int * x_qs = (int *) x_tile;
|
|
490
511
|
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
|
491
512
|
#else
|
|
492
513
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
|
|
493
514
|
int * x_qs = (int *) x_tile;
|
|
494
515
|
float * x_df = (float *) (x_qs + txs.qs);
|
|
495
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
516
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
496
517
|
|
|
497
518
|
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
|
|
498
519
|
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -527,13 +548,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
527
548
|
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
|
|
528
549
|
qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
|
|
529
550
|
|
|
530
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
551
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
531
552
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
|
|
532
553
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
|
|
533
554
|
#else
|
|
534
555
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
|
|
535
556
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
|
|
536
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
557
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
537
558
|
}
|
|
538
559
|
|
|
539
560
|
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
|
|
@@ -550,11 +571,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
550
571
|
|
|
551
572
|
const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
|
|
552
573
|
|
|
553
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
574
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
554
575
|
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
|
|
555
576
|
#else
|
|
556
577
|
x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
|
|
557
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
578
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
558
579
|
}
|
|
559
580
|
}
|
|
560
581
|
|
|
@@ -563,14 +584,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
563
584
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
564
585
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
565
586
|
|
|
566
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
587
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
567
588
|
int * x_qs = (int *) x_tile;
|
|
568
589
|
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
|
|
569
590
|
#else
|
|
570
591
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
|
|
571
592
|
int * x_qs = (int *) x_tile;
|
|
572
593
|
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
|
573
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
594
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
574
595
|
|
|
575
596
|
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
|
|
576
597
|
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -603,13 +624,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
603
624
|
qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
|
|
604
625
|
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
|
|
605
626
|
|
|
606
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
627
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
607
628
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
|
|
608
629
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
|
|
609
630
|
#else
|
|
610
631
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
|
|
611
632
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
|
|
612
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
633
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
613
634
|
}
|
|
614
635
|
|
|
615
636
|
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
|
|
@@ -626,11 +647,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
626
647
|
|
|
627
648
|
const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
|
|
628
649
|
|
|
629
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
650
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
630
651
|
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
|
|
631
652
|
#else
|
|
632
653
|
x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
|
|
633
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
654
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
634
655
|
}
|
|
635
656
|
}
|
|
636
657
|
|
|
@@ -639,14 +660,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
639
660
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
640
661
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
641
662
|
|
|
642
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
663
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
643
664
|
int * x_qs = (int *) x_tile;
|
|
644
665
|
float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
|
|
645
666
|
#else
|
|
646
667
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
|
|
647
668
|
int * x_qs = (int *) x_tile;
|
|
648
669
|
float * x_df = (float *) (x_qs + txs.qs);
|
|
649
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
670
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
650
671
|
|
|
651
672
|
// MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp
|
|
652
673
|
constexpr int threads_per_row = 32;
|
|
@@ -665,13 +686,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
665
686
|
|
|
666
687
|
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
|
|
667
688
|
|
|
668
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
689
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
669
690
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
|
|
670
691
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
|
|
671
692
|
#else
|
|
672
693
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
|
|
673
694
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
|
|
674
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
695
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
675
696
|
}
|
|
676
697
|
|
|
677
698
|
constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
|
|
@@ -688,11 +709,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
688
709
|
|
|
689
710
|
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
|
|
690
711
|
|
|
691
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
712
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
692
713
|
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
|
|
693
714
|
#else
|
|
694
715
|
x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
|
|
695
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
716
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
696
717
|
}
|
|
697
718
|
}
|
|
698
719
|
|
|
@@ -701,14 +722,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
701
722
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
702
723
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
703
724
|
|
|
704
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
725
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
705
726
|
int * x_qs = (int *) x_tile;
|
|
706
727
|
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
|
707
728
|
#else
|
|
708
729
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
|
|
709
730
|
int * x_qs = (int *) x_tile;
|
|
710
731
|
float * x_df = (float *) (x_qs + txs.qs);
|
|
711
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
732
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
712
733
|
|
|
713
734
|
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
|
|
714
735
|
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -730,13 +751,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
730
751
|
const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
|
|
731
752
|
const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
|
|
732
753
|
|
|
733
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
754
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
734
755
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x;
|
|
735
756
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
|
|
736
757
|
#else
|
|
737
758
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
|
|
738
759
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
|
|
739
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
760
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
740
761
|
}
|
|
741
762
|
|
|
742
763
|
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
|
|
@@ -753,11 +774,55 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
753
774
|
|
|
754
775
|
const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
|
|
755
776
|
|
|
756
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
777
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
757
778
|
x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
|
|
758
779
|
#else
|
|
759
780
|
x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
|
|
760
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
781
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
782
|
+
}
|
|
783
|
+
}
|
|
784
|
+
|
|
785
|
+
template <int mmq_y, bool need_check>
|
|
786
|
+
static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restrict__ x,
|
|
787
|
+
int * __restrict__ x_tile,
|
|
788
|
+
const int kbx0,
|
|
789
|
+
const int i_max,
|
|
790
|
+
const int stride) {
|
|
791
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
|
792
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
793
|
+
|
|
794
|
+
int * x_qs = (int *) x_tile;
|
|
795
|
+
uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
|
|
796
|
+
|
|
797
|
+
const int txi = threadIdx.x;
|
|
798
|
+
|
|
799
|
+
constexpr int iter_k = get_iter_k(GGML_TYPE_MXFP4);
|
|
800
|
+
|
|
801
|
+
constexpr int threads_per_row = iter_k / QK_MXFP4; // each thread processes 1 block
|
|
802
|
+
constexpr int rows_per_warp = warp_size / threads_per_row;
|
|
803
|
+
const int kbx = txi % threads_per_row;
|
|
804
|
+
const int row_in_warp = txi / threads_per_row;
|
|
805
|
+
|
|
806
|
+
#pragma unroll
|
|
807
|
+
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
|
|
808
|
+
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
|
|
809
|
+
|
|
810
|
+
if constexpr (need_check) {
|
|
811
|
+
i = min(i, i_max);
|
|
812
|
+
}
|
|
813
|
+
|
|
814
|
+
const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
|
|
815
|
+
|
|
816
|
+
// quantize_mxfp4_mmq permutes nibbles to match the quantized format
|
|
817
|
+
const int k0 = kbx * 4;
|
|
818
|
+
memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
|
|
819
|
+
|
|
820
|
+
// Load E8M0 scales: pack 2 consecutive scales into one uint32
|
|
821
|
+
if (kbx % 2 == 0) {
|
|
822
|
+
uint32_t e = bxi->e;
|
|
823
|
+
e |= ((bxi + 1)->e << 8);
|
|
824
|
+
x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
|
|
825
|
+
}
|
|
761
826
|
}
|
|
762
827
|
}
|
|
763
828
|
|
|
@@ -796,10 +861,11 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
|
|
|
796
861
|
template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
|
|
797
862
|
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
|
798
863
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
799
|
-
#if defined(AMD_MFMA_AVAILABLE)
|
|
800
|
-
typedef tile<16, 8, int> tile_A;
|
|
801
|
-
typedef tile<16, 8, int> tile_B;
|
|
802
|
-
typedef tile<16, 16, int> tile_C;
|
|
864
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
865
|
+
constexpr data_layout input_layout = get_input_data_layout();
|
|
866
|
+
typedef tile<16, 8, int, input_layout> tile_A;
|
|
867
|
+
typedef tile<16, 8, int, input_layout> tile_B;
|
|
868
|
+
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
|
803
869
|
|
|
804
870
|
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
805
871
|
constexpr int rows_per_warp = granularity;
|
|
@@ -927,7 +993,79 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
|
|
927
993
|
}
|
|
928
994
|
}
|
|
929
995
|
}
|
|
930
|
-
#endif // defined(AMD_MFMA_AVAILABLE)
|
|
996
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
997
|
+
}
|
|
998
|
+
|
|
999
|
+
template <int mmq_x, int mmq_y>
|
|
1000
|
+
static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
|
|
1001
|
+
const int * __restrict__ y,
|
|
1002
|
+
float * __restrict__ sum,
|
|
1003
|
+
const int k00) {
|
|
1004
|
+
typedef tile<16, 8, int> tile_A;
|
|
1005
|
+
typedef tile<8, 8, int> tile_B;
|
|
1006
|
+
typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA
|
|
1007
|
+
|
|
1008
|
+
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
1009
|
+
constexpr int rows_per_warp = 2 * granularity;
|
|
1010
|
+
constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.
|
|
1011
|
+
|
|
1012
|
+
y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
|
|
1013
|
+
|
|
1014
|
+
// Match layout from load_tiles_mxfp4_fp4
|
|
1015
|
+
const int * x_qs = (const int *) x;
|
|
1016
|
+
const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
|
|
1017
|
+
const int * y_qs = (const int *) y + 4;
|
|
1018
|
+
const uint32_t * y_sc = (const uint32_t *) y;
|
|
1019
|
+
|
|
1020
|
+
// tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
|
|
1021
|
+
tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
|
|
1022
|
+
uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
|
|
1023
|
+
|
|
1024
|
+
// Block scale
|
|
1025
|
+
// Each thread has to point to a 4 byte scale value
|
|
1026
|
+
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
|
|
1027
|
+
|
|
1028
|
+
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
|
1029
|
+
|
|
1030
|
+
#pragma unroll
|
|
1031
|
+
for (int n = 0; n < ntx; ++n) {
|
|
1032
|
+
#pragma unroll
|
|
1033
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
|
|
1034
|
+
const int k0 = k00 + k01;
|
|
1035
|
+
|
|
1036
|
+
load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
|
|
1037
|
+
MMQ_MMA_TILE_X_K_FP4);
|
|
1038
|
+
|
|
1039
|
+
// based on block-scaling document, 2 threads in each quad need to supply to the scale value
|
|
1040
|
+
const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
|
|
1041
|
+
scaleA[n][k01 / (2 * QI_MXFP4)] =
|
|
1042
|
+
*(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
|
|
1043
|
+
}
|
|
1044
|
+
}
|
|
1045
|
+
|
|
1046
|
+
#pragma unroll
|
|
1047
|
+
for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
|
|
1048
|
+
#pragma unroll
|
|
1049
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
|
|
1050
|
+
tile_B B;
|
|
1051
|
+
uint32_t scaleB; // 2xN scales
|
|
1052
|
+
|
|
1053
|
+
load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
|
|
1054
|
+
|
|
1055
|
+
scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
|
|
1056
|
+
|
|
1057
|
+
#pragma unroll
|
|
1058
|
+
for (int n = 0; n < ntx; ++n) {
|
|
1059
|
+
tile_C C;
|
|
1060
|
+
|
|
1061
|
+
mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
|
|
1062
|
+
#pragma unroll
|
|
1063
|
+
for (int l = 0; l < tile_C::ne; ++l) {
|
|
1064
|
+
sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
|
|
1065
|
+
}
|
|
1066
|
+
}
|
|
1067
|
+
}
|
|
1068
|
+
}
|
|
931
1069
|
}
|
|
932
1070
|
|
|
933
1071
|
template <int mmq_x, int mmq_y>
|
|
@@ -965,10 +1103,11 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
|
|
|
965
1103
|
template <int mmq_x, int mmq_y>
|
|
966
1104
|
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
|
967
1105
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
968
|
-
#if defined(AMD_MFMA_AVAILABLE)
|
|
969
|
-
typedef tile<16, 8, int> tile_A;
|
|
970
|
-
typedef tile<16, 8, int> tile_B;
|
|
971
|
-
typedef tile<16, 16, int> tile_C;
|
|
1106
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1107
|
+
constexpr data_layout input_layout = get_input_data_layout();
|
|
1108
|
+
typedef tile<16, 8, int, input_layout> tile_A;
|
|
1109
|
+
typedef tile<16, 8, int, input_layout> tile_B;
|
|
1110
|
+
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
|
972
1111
|
|
|
973
1112
|
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
974
1113
|
constexpr int rows_per_warp = granularity;
|
|
@@ -1087,7 +1226,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
|
|
1087
1226
|
}
|
|
1088
1227
|
}
|
|
1089
1228
|
}
|
|
1090
|
-
#endif // defined(AMD_MFMA_AVAILABLE)
|
|
1229
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1091
1230
|
}
|
|
1092
1231
|
|
|
1093
1232
|
// Used for Q3_K, IQ2_S, and IQ2_XS
|
|
@@ -1130,10 +1269,11 @@ template <int mmq_x, int mmq_y>
|
|
|
1130
1269
|
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
|
1131
1270
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
1132
1271
|
#if defined(AMD_MFMA_AVAILABLE)
|
|
1133
|
-
typedef tile<16, 8, int> tile_A;
|
|
1134
|
-
typedef tile<16, 8, int> tile_B;
|
|
1135
|
-
typedef tile<16, 16, int> tile_C;
|
|
1136
|
-
typedef tile<64, 2, int> tile_load;
|
|
1272
|
+
constexpr data_layout input_layout = get_input_data_layout();
|
|
1273
|
+
typedef tile<16, 8, int, input_layout> tile_A;
|
|
1274
|
+
typedef tile<16, 8, int, input_layout> tile_B;
|
|
1275
|
+
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
|
1276
|
+
typedef tile<64, 2, int, input_layout> tile_load;
|
|
1137
1277
|
|
|
1138
1278
|
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
1139
1279
|
constexpr int rows_per_warp = granularity;
|
|
@@ -1170,6 +1310,55 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
|
|
1170
1310
|
tile_C C;
|
|
1171
1311
|
mma(C, A[n], B[0]);
|
|
1172
1312
|
|
|
1313
|
+
#pragma unroll
|
|
1314
|
+
for (int l = 0; l < tile_C::ne; ++l) {
|
|
1315
|
+
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
|
|
1316
|
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
|
|
1317
|
+
}
|
|
1318
|
+
}
|
|
1319
|
+
}
|
|
1320
|
+
}
|
|
1321
|
+
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
|
1322
|
+
constexpr data_layout input_layout = get_input_data_layout();
|
|
1323
|
+
typedef tile<16, 4, int, input_layout> tile_A;
|
|
1324
|
+
typedef tile<16, 4, int, input_layout> tile_B;
|
|
1325
|
+
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
|
1326
|
+
|
|
1327
|
+
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
1328
|
+
constexpr int rows_per_warp = granularity;
|
|
1329
|
+
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
|
1330
|
+
|
|
1331
|
+
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
|
1332
|
+
|
|
1333
|
+
const int * x_qs = (const int *) x;
|
|
1334
|
+
const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
|
|
1335
|
+
const int * y_qs = (const int *) y + 4;
|
|
1336
|
+
const float * y_df = (const float *) y;
|
|
1337
|
+
|
|
1338
|
+
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
|
1339
|
+
|
|
1340
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
|
|
1341
|
+
const int k0 = k00 + k01;
|
|
1342
|
+
|
|
1343
|
+
tile_A A[ntx];
|
|
1344
|
+
#pragma unroll
|
|
1345
|
+
for (int n = 0; n < ntx; ++n) {
|
|
1346
|
+
load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
|
|
1347
|
+
}
|
|
1348
|
+
|
|
1349
|
+
#pragma unroll
|
|
1350
|
+
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
1351
|
+
tile_B B;
|
|
1352
|
+
load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
1353
|
+
|
|
1354
|
+
const int j = j0 + tile_C::get_j(0);
|
|
1355
|
+
const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
|
|
1356
|
+
|
|
1357
|
+
#pragma unroll
|
|
1358
|
+
for (int n = 0; n < ntx; ++n) {
|
|
1359
|
+
tile_C C;
|
|
1360
|
+
mma(C, A[n], B);
|
|
1361
|
+
|
|
1173
1362
|
#pragma unroll
|
|
1174
1363
|
for (int l = 0; l < tile_C::ne; ++l) {
|
|
1175
1364
|
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
|
|
@@ -1257,21 +1446,21 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
|
|
1257
1446
|
#else
|
|
1258
1447
|
GGML_UNUSED_VARS(x, y, sum, k00);
|
|
1259
1448
|
NO_DEVICE_CODE;
|
|
1260
|
-
#endif // AMD_MFMA_AVAILABLE
|
|
1449
|
+
#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE
|
|
1261
1450
|
}
|
|
1262
1451
|
|
|
1263
1452
|
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
|
|
1264
1453
|
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
|
1265
1454
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
1266
1455
|
|
|
1267
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
1456
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1268
1457
|
int * x_qs = (int *) x_tile;
|
|
1269
1458
|
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
|
|
1270
1459
|
#else
|
|
1271
1460
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
|
|
1272
1461
|
int * x_qs = (int *) x_tile;
|
|
1273
1462
|
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
|
1274
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
1463
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1275
1464
|
|
|
1276
1465
|
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
|
|
1277
1466
|
constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
|
|
@@ -1295,11 +1484,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
1295
1484
|
|
|
1296
1485
|
const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
|
|
1297
1486
|
|
|
1298
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
1487
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1299
1488
|
x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
|
|
1300
1489
|
#else
|
|
1301
1490
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
|
|
1302
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
1491
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1303
1492
|
}
|
|
1304
1493
|
|
|
1305
1494
|
const int sc_m = bxi->scales[kqsx];
|
|
@@ -1310,11 +1499,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
1310
1499
|
const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
|
|
1311
1500
|
#endif // FAST_FP16_AVAILABLE
|
|
1312
1501
|
|
|
1313
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
1502
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1314
1503
|
x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
|
|
1315
1504
|
#else
|
|
1316
1505
|
x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik;
|
|
1317
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
1506
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1318
1507
|
}
|
|
1319
1508
|
}
|
|
1320
1509
|
|
|
@@ -1387,10 +1576,11 @@ template <int mmq_x, int mmq_y>
|
|
|
1387
1576
|
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
|
1388
1577
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
1389
1578
|
#if defined(AMD_MFMA_AVAILABLE)
|
|
1390
|
-
typedef tile<16, 8, int> tile_A;
|
|
1391
|
-
typedef tile<16, 8, int> tile_B;
|
|
1392
|
-
typedef tile<16, 16, int> tile_C;
|
|
1393
|
-
typedef tile<64, 2, int> tile_load;
|
|
1579
|
+
constexpr data_layout input_layout = get_input_data_layout();
|
|
1580
|
+
typedef tile<16, 8, int, input_layout> tile_A;
|
|
1581
|
+
typedef tile<16, 8, int, input_layout> tile_B;
|
|
1582
|
+
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
|
1583
|
+
typedef tile<64, 2, int, input_layout> tile_load;
|
|
1394
1584
|
|
|
1395
1585
|
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
1396
1586
|
constexpr int rows_per_warp = granularity;
|
|
@@ -1438,6 +1628,74 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
|
|
1438
1628
|
tile_C Cd;
|
|
1439
1629
|
mma(Cd, A[n], B[0]);
|
|
1440
1630
|
|
|
1631
|
+
#pragma unroll
|
|
1632
|
+
for (int l = 0; l < tile_C::ne; ++l) {
|
|
1633
|
+
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
|
|
1634
|
+
const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
|
|
1635
|
+
float tmp = Cd.x[l]*dm.x;
|
|
1636
|
+
if (k01 >= MMQ_TILE_NE_K * 3/4) {
|
|
1637
|
+
tmp -= Cm.x[l]*dm.y;
|
|
1638
|
+
}
|
|
1639
|
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
|
|
1640
|
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
|
|
1641
|
+
}
|
|
1642
|
+
}
|
|
1643
|
+
}
|
|
1644
|
+
}
|
|
1645
|
+
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
|
1646
|
+
constexpr data_layout input_layout = get_input_data_layout();
|
|
1647
|
+
typedef tile<16, 4, int, input_layout> tile_A;
|
|
1648
|
+
typedef tile<16, 4, int, input_layout> tile_B;
|
|
1649
|
+
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
|
1650
|
+
|
|
1651
|
+
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
1652
|
+
constexpr int rows_per_warp = granularity;
|
|
1653
|
+
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
|
1654
|
+
|
|
1655
|
+
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
|
1656
|
+
|
|
1657
|
+
const int * x_qs = (const int *) x;
|
|
1658
|
+
const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
|
|
1659
|
+
const int * y_qs = (const int *) y + 4;
|
|
1660
|
+
const half2 * y_ds = (const half2 *) y;
|
|
1661
|
+
|
|
1662
|
+
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
|
1663
|
+
|
|
1664
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
|
|
1665
|
+
const int k0 = k00 + k01;
|
|
1666
|
+
|
|
1667
|
+
tile_A A[ntx];
|
|
1668
|
+
#pragma unroll
|
|
1669
|
+
for (int n = 0; n < ntx; ++n) {
|
|
1670
|
+
load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
|
|
1671
|
+
}
|
|
1672
|
+
|
|
1673
|
+
#pragma unroll
|
|
1674
|
+
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
1675
|
+
tile_B B;
|
|
1676
|
+
load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
1677
|
+
|
|
1678
|
+
const int j = j0 + tile_C::get_j(0);
|
|
1679
|
+
const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y;
|
|
1680
|
+
const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
|
|
1681
|
+
: (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
|
|
1682
|
+
: __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
|
|
1683
|
+
|
|
1684
|
+
tile_C Cm;
|
|
1685
|
+
if (k01 >= MMQ_TILE_NE_K * 3/4) {
|
|
1686
|
+
tile_A A1;
|
|
1687
|
+
#pragma unroll
|
|
1688
|
+
for (int l = 0; l < tile_A::ne; ++l) {
|
|
1689
|
+
A1.x[l] = 0x01010101;
|
|
1690
|
+
}
|
|
1691
|
+
mma(Cm, A1, B);
|
|
1692
|
+
}
|
|
1693
|
+
|
|
1694
|
+
#pragma unroll
|
|
1695
|
+
for (int n = 0; n < ntx; ++n) {
|
|
1696
|
+
tile_C Cd;
|
|
1697
|
+
mma(Cd, A[n], B);
|
|
1698
|
+
|
|
1441
1699
|
#pragma unroll
|
|
1442
1700
|
for (int l = 0; l < tile_C::ne; ++l) {
|
|
1443
1701
|
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
|
|
@@ -1574,7 +1832,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
|
|
1574
1832
|
#else
|
|
1575
1833
|
GGML_UNUSED_VARS(x, y, sum, k00);
|
|
1576
1834
|
NO_DEVICE_CODE;
|
|
1577
|
-
#endif // AMD_MFMA_AVAILABLE
|
|
1835
|
+
#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE
|
|
1578
1836
|
}
|
|
1579
1837
|
|
|
1580
1838
|
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
|
|
@@ -1582,7 +1840,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
1582
1840
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
1583
1841
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
1584
1842
|
|
|
1585
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
1843
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1586
1844
|
int * x_qs = (int *) x_tile;
|
|
1587
1845
|
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
|
1588
1846
|
#else
|
|
@@ -1618,11 +1876,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
1618
1876
|
|
|
1619
1877
|
const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
|
|
1620
1878
|
|
|
1621
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
1879
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1622
1880
|
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
|
|
1623
1881
|
#else
|
|
1624
1882
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
|
|
1625
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
1883
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1626
1884
|
}
|
|
1627
1885
|
}
|
|
1628
1886
|
|
|
@@ -1649,7 +1907,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
1649
1907
|
|
|
1650
1908
|
const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
|
|
1651
1909
|
|
|
1652
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
1910
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1653
1911
|
const int8_t * sc8 = (const int8_t *) &sc;
|
|
1654
1912
|
const float d = bxi->d;
|
|
1655
1913
|
|
|
@@ -1659,10 +1917,10 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
1659
1917
|
}
|
|
1660
1918
|
#else
|
|
1661
1919
|
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
|
|
1662
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
1920
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1663
1921
|
}
|
|
1664
1922
|
|
|
1665
|
-
#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
|
|
1923
|
+
#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE))
|
|
1666
1924
|
#pragma unroll
|
|
1667
1925
|
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
|
|
1668
1926
|
int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
|
|
@@ -1675,7 +1933,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
1675
1933
|
|
|
1676
1934
|
x_df[i] = bxi->d;
|
|
1677
1935
|
}
|
|
1678
|
-
#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
|
|
1936
|
+
#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE))
|
|
1679
1937
|
}
|
|
1680
1938
|
|
|
1681
1939
|
template <int mmq_x, int mmq_y>
|
|
@@ -1728,7 +1986,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
1728
1986
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
1729
1987
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
1730
1988
|
|
|
1731
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
1989
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1732
1990
|
int * x_qs = (int *) x_tile;
|
|
1733
1991
|
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
|
|
1734
1992
|
#else
|
|
@@ -1736,7 +1994,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
1736
1994
|
int * x_qs = (int *) x_tile;
|
|
1737
1995
|
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
|
1738
1996
|
int * x_sc = (int *) (x_dm + txs.dm);
|
|
1739
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
1997
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1740
1998
|
|
|
1741
1999
|
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
|
|
1742
2000
|
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -1753,19 +2011,19 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
1753
2011
|
const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
|
|
1754
2012
|
const int qs0 = get_int_b4(bxi->qs, txi);
|
|
1755
2013
|
|
|
1756
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2014
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1757
2015
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
|
|
1758
2016
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
|
|
1759
2017
|
#else
|
|
1760
2018
|
x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
|
|
1761
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2019
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1762
2020
|
}
|
|
1763
2021
|
|
|
1764
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2022
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1765
2023
|
constexpr int rows_per_warp = warp_size / 2;
|
|
1766
2024
|
#pragma unroll
|
|
1767
2025
|
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
|
|
1768
|
-
#if defined(AMD_MFMA_AVAILABLE)
|
|
2026
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1769
2027
|
// Need if on AMD instead of % because warp_size == 64
|
|
1770
2028
|
// This causes double work and throughput loss (MI300X)
|
|
1771
2029
|
// H100 loses about 100 t/s with 'if' condition over '%'
|
|
@@ -1774,7 +2032,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
1774
2032
|
#else
|
|
1775
2033
|
int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
|
|
1776
2034
|
{
|
|
1777
|
-
#endif // defined(AMD_MFMA_AVAILABLE)
|
|
2035
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1778
2036
|
if (need_check) {
|
|
1779
2037
|
i = min(i, i_max);
|
|
1780
2038
|
}
|
|
@@ -1829,7 +2087,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
1829
2087
|
|
|
1830
2088
|
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
|
|
1831
2089
|
}
|
|
1832
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2090
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1833
2091
|
}
|
|
1834
2092
|
|
|
1835
2093
|
template <int mmq_x, int mmq_y>
|
|
@@ -1872,7 +2130,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
1872
2130
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
1873
2131
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
1874
2132
|
|
|
1875
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2133
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1876
2134
|
int * x_qs = (int *) x_tile;
|
|
1877
2135
|
half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
|
|
1878
2136
|
#else
|
|
@@ -1908,16 +2166,16 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
1908
2166
|
const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
|
|
1909
2167
|
const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;
|
|
1910
2168
|
|
|
1911
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2169
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1912
2170
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
|
|
1913
2171
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
|
|
1914
2172
|
#else
|
|
1915
2173
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
|
|
1916
2174
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
|
|
1917
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2175
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1918
2176
|
}
|
|
1919
2177
|
|
|
1920
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2178
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1921
2179
|
constexpr int rows_per_warp = warp_size / 2;
|
|
1922
2180
|
#pragma unroll
|
|
1923
2181
|
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
|
|
@@ -1930,7 +2188,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
1930
2188
|
#else
|
|
1931
2189
|
int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
|
|
1932
2190
|
{
|
|
1933
|
-
#endif // defined(AMD_MFMA_AVAILABLE)
|
|
2191
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1934
2192
|
if (need_check) {
|
|
1935
2193
|
i = min(i, i_max);
|
|
1936
2194
|
}
|
|
@@ -1986,7 +2244,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
1986
2244
|
|
|
1987
2245
|
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
|
|
1988
2246
|
}
|
|
1989
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2247
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
1990
2248
|
}
|
|
1991
2249
|
|
|
1992
2250
|
template <int mmq_x, int mmq_y>
|
|
@@ -2029,7 +2287,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2029
2287
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
2030
2288
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
2031
2289
|
|
|
2032
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2290
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2033
2291
|
int * x_qs = (int *) x_tile;
|
|
2034
2292
|
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
|
2035
2293
|
int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K);
|
|
@@ -2038,7 +2296,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2038
2296
|
int * x_qs = (int *) x_tile;
|
|
2039
2297
|
float * x_df = (float *) (x_qs + txs.qs);
|
|
2040
2298
|
int * x_sc = (int *) (x_df + txs.dm);
|
|
2041
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2299
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2042
2300
|
|
|
2043
2301
|
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
|
|
2044
2302
|
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -2065,13 +2323,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2065
2323
|
const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
|
|
2066
2324
|
const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;
|
|
2067
2325
|
|
|
2068
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2326
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2069
2327
|
x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
|
|
2070
2328
|
x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
|
|
2071
2329
|
#else
|
|
2072
2330
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
|
|
2073
2331
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
|
|
2074
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2332
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2075
2333
|
}
|
|
2076
2334
|
|
|
2077
2335
|
#pragma unroll
|
|
@@ -2084,11 +2342,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2084
2342
|
|
|
2085
2343
|
const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
|
|
2086
2344
|
|
|
2087
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2345
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2088
2346
|
x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d;
|
|
2089
2347
|
#else
|
|
2090
2348
|
x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
|
|
2091
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2349
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2092
2350
|
}
|
|
2093
2351
|
|
|
2094
2352
|
constexpr int rows_per_warp = warp_size / 4;
|
|
@@ -2102,11 +2360,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2102
2360
|
|
|
2103
2361
|
const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;
|
|
2104
2362
|
|
|
2105
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2363
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2106
2364
|
x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
|
|
2107
2365
|
#else
|
|
2108
2366
|
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
|
|
2109
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2367
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2110
2368
|
}
|
|
2111
2369
|
}
|
|
2112
2370
|
|
|
@@ -2149,10 +2407,11 @@ template <int mmq_x, int mmq_y>
|
|
|
2149
2407
|
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
|
2150
2408
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
2151
2409
|
#if defined(AMD_MFMA_AVAILABLE)
|
|
2152
|
-
|
|
2153
|
-
typedef tile<16, 8, int>
|
|
2154
|
-
typedef tile<16,
|
|
2155
|
-
typedef tile<
|
|
2410
|
+
constexpr data_layout input_layout = get_input_data_layout();
|
|
2411
|
+
typedef tile<16, 8, int, input_layout> tile_A;
|
|
2412
|
+
typedef tile<16, 8, int, input_layout> tile_B;
|
|
2413
|
+
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
|
2414
|
+
typedef tile<64, 2, int, input_layout> tile_load;
|
|
2156
2415
|
|
|
2157
2416
|
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
2158
2417
|
constexpr int rows_per_warp = granularity;
|
|
@@ -2190,6 +2449,57 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
|
|
2190
2449
|
tile_C C;
|
|
2191
2450
|
mma(C, A[n], B[0]);
|
|
2192
2451
|
|
|
2452
|
+
#pragma unroll
|
|
2453
|
+
for (int l = 0; l < tile_C::ne; ++l) {
|
|
2454
|
+
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
|
|
2455
|
+
const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
|
|
2456
|
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
|
|
2457
|
+
}
|
|
2458
|
+
}
|
|
2459
|
+
}
|
|
2460
|
+
}
|
|
2461
|
+
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
|
2462
|
+
constexpr data_layout input_layout = get_input_data_layout();
|
|
2463
|
+
typedef tile<16, 4, int, input_layout> tile_A;
|
|
2464
|
+
typedef tile<16, 4, int, input_layout> tile_B;
|
|
2465
|
+
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
|
2466
|
+
|
|
2467
|
+
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
2468
|
+
constexpr int rows_per_warp = granularity;
|
|
2469
|
+
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
|
2470
|
+
|
|
2471
|
+
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
|
2472
|
+
|
|
2473
|
+
const int * x_qs = (const int *) x;
|
|
2474
|
+
const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
|
|
2475
|
+
const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
|
|
2476
|
+
const int * y_qs = (const int *) y + 4;
|
|
2477
|
+
const float * y_df = (const float *) y;
|
|
2478
|
+
|
|
2479
|
+
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
|
2480
|
+
|
|
2481
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
|
|
2482
|
+
const int k0 = k00 + k01;
|
|
2483
|
+
|
|
2484
|
+
tile_A A[ntx];
|
|
2485
|
+
#pragma unroll
|
|
2486
|
+
for (int n = 0; n < ntx; ++n) {
|
|
2487
|
+
load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
|
|
2488
|
+
}
|
|
2489
|
+
|
|
2490
|
+
#pragma unroll
|
|
2491
|
+
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
2492
|
+
tile_B B;
|
|
2493
|
+
load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
2494
|
+
|
|
2495
|
+
const int j = j0 + tile_C::get_j(0);
|
|
2496
|
+
const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
|
|
2497
|
+
|
|
2498
|
+
#pragma unroll
|
|
2499
|
+
for (int n = 0; n < ntx; ++n) {
|
|
2500
|
+
tile_C C;
|
|
2501
|
+
mma(C, A[n], B);
|
|
2502
|
+
|
|
2193
2503
|
#pragma unroll
|
|
2194
2504
|
for (int l = 0; l < tile_C::ne; ++l) {
|
|
2195
2505
|
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
|
|
@@ -2303,7 +2613,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
|
|
2303
2613
|
#else
|
|
2304
2614
|
GGML_UNUSED_VARS(x, y, sum, k00);
|
|
2305
2615
|
NO_DEVICE_CODE;
|
|
2306
|
-
#endif // AMD_MFMA_AVAILABLE
|
|
2616
|
+
#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE
|
|
2307
2617
|
}
|
|
2308
2618
|
|
|
2309
2619
|
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
|
|
@@ -2311,14 +2621,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2311
2621
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
2312
2622
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
2313
2623
|
|
|
2314
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2624
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2315
2625
|
int * x_qs = (int *) x_tile;
|
|
2316
2626
|
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
|
2317
2627
|
#else
|
|
2318
2628
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
|
|
2319
2629
|
int * x_qs = (int *) x_tile;
|
|
2320
2630
|
float * x_df = (float *) (x_qs + txs.qs);
|
|
2321
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2631
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2322
2632
|
|
|
2323
2633
|
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
|
|
2324
2634
|
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -2340,13 +2650,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2340
2650
|
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
|
|
2341
2651
|
const int k0 = kbx * (2 * QI4_NL) + kqsx;
|
|
2342
2652
|
|
|
2343
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2653
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2344
2654
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
|
|
2345
2655
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
|
|
2346
2656
|
#else
|
|
2347
2657
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
|
|
2348
2658
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
|
|
2349
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2659
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2350
2660
|
}
|
|
2351
2661
|
|
|
2352
2662
|
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
|
|
@@ -2363,11 +2673,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2363
2673
|
|
|
2364
2674
|
const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
|
|
2365
2675
|
|
|
2366
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2676
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2367
2677
|
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
|
|
2368
2678
|
#else
|
|
2369
2679
|
x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
|
|
2370
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2680
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2371
2681
|
}
|
|
2372
2682
|
}
|
|
2373
2683
|
|
|
@@ -2376,14 +2686,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2376
2686
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
2377
2687
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
2378
2688
|
|
|
2379
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2689
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2380
2690
|
int * x_qs = (int *) x_tile;
|
|
2381
2691
|
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
|
2382
2692
|
#else
|
|
2383
2693
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
|
|
2384
2694
|
int * x_qs = (int *) x_tile;
|
|
2385
2695
|
float * x_df = (float *) (x_qs + txs.qs);
|
|
2386
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2696
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2387
2697
|
|
|
2388
2698
|
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
|
|
2389
2699
|
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -2414,22 +2724,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2414
2724
|
const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
|
|
2415
2725
|
const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
|
|
2416
2726
|
|
|
2417
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2727
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2418
2728
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
|
|
2419
2729
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
|
|
2420
2730
|
#else
|
|
2421
2731
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
|
|
2422
2732
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
|
|
2423
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2733
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2424
2734
|
}
|
|
2425
2735
|
|
|
2426
2736
|
const int ls = aux32 >> 28;
|
|
2427
2737
|
const float d = bxi->d;
|
|
2428
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2738
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2429
2739
|
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
|
|
2430
2740
|
#else
|
|
2431
2741
|
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
|
|
2432
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2742
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2433
2743
|
}
|
|
2434
2744
|
}
|
|
2435
2745
|
|
|
@@ -2438,14 +2748,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2438
2748
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
2439
2749
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
2440
2750
|
|
|
2441
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2751
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2442
2752
|
int * x_qs = (int *) x_tile;
|
|
2443
2753
|
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
|
2444
2754
|
#else
|
|
2445
2755
|
constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
|
|
2446
2756
|
int * x_qs = (int *) x_tile;
|
|
2447
2757
|
float * x_df = (float *) (x_qs + txs.qs);
|
|
2448
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2758
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2449
2759
|
|
|
2450
2760
|
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
|
|
2451
2761
|
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -2472,24 +2782,24 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2472
2782
|
const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
|
|
2473
2783
|
const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
|
|
2474
2784
|
|
|
2475
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2785
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2476
2786
|
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
|
|
2477
2787
|
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
|
|
2478
2788
|
#else
|
|
2479
2789
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
|
|
2480
2790
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
|
|
2481
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2791
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2482
2792
|
}
|
|
2483
2793
|
|
|
2484
2794
|
const int ls = bxi->scales[kqsx];
|
|
2485
2795
|
const float d = bxi->d;
|
|
2486
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2796
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2487
2797
|
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
|
|
2488
2798
|
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
|
|
2489
2799
|
#else
|
|
2490
2800
|
x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
|
|
2491
2801
|
x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
|
|
2492
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2802
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2493
2803
|
}
|
|
2494
2804
|
}
|
|
2495
2805
|
|
|
@@ -2498,15 +2808,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2498
2808
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
2499
2809
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
2500
2810
|
|
|
2501
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2811
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2502
2812
|
int * x_qs = (int *) x_tile;
|
|
2503
2813
|
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
|
2504
2814
|
#else
|
|
2505
2815
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
|
|
2506
2816
|
int * x_qs = (int *) x_tile;
|
|
2507
2817
|
float * x_df = (float *) (x_qs + txs.qs);
|
|
2508
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2509
|
-
|
|
2818
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2510
2819
|
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
|
|
2511
2820
|
constexpr int nrows = warp_size / threads_per_row;
|
|
2512
2821
|
const int kqsx = threadIdx.x % threads_per_row;
|
|
@@ -2539,24 +2848,24 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2539
2848
|
const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
|
|
2540
2849
|
const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
|
|
2541
2850
|
|
|
2542
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2851
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2543
2852
|
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
|
|
2544
2853
|
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
|
|
2545
2854
|
#else
|
|
2546
2855
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
|
|
2547
2856
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
|
|
2548
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2857
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2549
2858
|
}
|
|
2550
2859
|
|
|
2551
2860
|
const int ls = bxi->scales[kqsx];
|
|
2552
2861
|
const float d = bxi->d;
|
|
2553
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2862
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2554
2863
|
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
|
|
2555
2864
|
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
|
|
2556
2865
|
#else
|
|
2557
2866
|
x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
|
|
2558
2867
|
x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
|
|
2559
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2868
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2560
2869
|
}
|
|
2561
2870
|
}
|
|
2562
2871
|
|
|
@@ -2565,14 +2874,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2565
2874
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
2566
2875
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
2567
2876
|
|
|
2568
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2877
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2569
2878
|
int * x_qs = (int *) x_tile;
|
|
2570
2879
|
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
|
2571
2880
|
#else
|
|
2572
2881
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
|
|
2573
2882
|
int * x_qs = (int *) x_tile;
|
|
2574
2883
|
float * x_df = (float *) (x_qs + txs.qs);
|
|
2575
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2884
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2576
2885
|
|
|
2577
2886
|
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
|
|
2578
2887
|
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -2601,22 +2910,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2601
2910
|
const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
|
|
2602
2911
|
const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
|
|
2603
2912
|
|
|
2604
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2913
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2605
2914
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
|
|
2606
2915
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
|
|
2607
2916
|
#else
|
|
2608
2917
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
|
|
2609
2918
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
|
|
2610
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2919
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2611
2920
|
}
|
|
2612
2921
|
|
|
2613
2922
|
const int ls = aux32 >> 28;
|
|
2614
2923
|
const float d = bxi->d;
|
|
2615
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2924
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2616
2925
|
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
|
|
2617
2926
|
#else
|
|
2618
2927
|
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2;
|
|
2619
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2928
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2620
2929
|
}
|
|
2621
2930
|
}
|
|
2622
2931
|
|
|
@@ -2625,14 +2934,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2625
2934
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
2626
2935
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
2627
2936
|
|
|
2628
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2937
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2629
2938
|
int * x_qs = (int *) x_tile;
|
|
2630
2939
|
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
|
2631
2940
|
#else
|
|
2632
2941
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
|
|
2633
2942
|
int * x_qs = (int *) x_tile;
|
|
2634
2943
|
float * x_df = (float *) (x_qs + txs.qs);
|
|
2635
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2944
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2636
2945
|
|
|
2637
2946
|
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
|
|
2638
2947
|
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -2668,22 +2977,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2668
2977
|
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
|
|
2669
2978
|
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
|
|
2670
2979
|
|
|
2671
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2980
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2672
2981
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
|
|
2673
2982
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
|
|
2674
2983
|
#else
|
|
2675
2984
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
|
|
2676
2985
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
|
|
2677
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2986
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2678
2987
|
}
|
|
2679
2988
|
|
|
2680
2989
|
const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
|
|
2681
2990
|
const float d = bxi->d;
|
|
2682
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2991
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2683
2992
|
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
|
|
2684
2993
|
#else
|
|
2685
2994
|
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d;
|
|
2686
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
2995
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2687
2996
|
}
|
|
2688
2997
|
}
|
|
2689
2998
|
|
|
@@ -2692,14 +3001,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2692
3001
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
2693
3002
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
2694
3003
|
|
|
2695
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
3004
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2696
3005
|
int * x_qs = (int *) x_tile;
|
|
2697
3006
|
half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
|
|
2698
3007
|
#else
|
|
2699
3008
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
|
|
2700
3009
|
int * x_qs = (int *) x_tile;
|
|
2701
3010
|
half2 * x_ds = (half2 *) (x_qs + txs.qs);
|
|
2702
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
3011
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2703
3012
|
|
|
2704
3013
|
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
|
|
2705
3014
|
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -2727,23 +3036,23 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2727
3036
|
const int grid0 = (grid >> 0) & 0x0F0F0F0F;
|
|
2728
3037
|
const int grid1 = (grid >> 4) & 0x0F0F0F0F;
|
|
2729
3038
|
|
|
2730
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
3039
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2731
3040
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
|
|
2732
3041
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
|
|
2733
3042
|
#else
|
|
2734
3043
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
|
|
2735
3044
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
|
|
2736
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
3045
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2737
3046
|
}
|
|
2738
3047
|
|
|
2739
3048
|
const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
|
|
2740
3049
|
const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
|
|
2741
3050
|
|
|
2742
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
3051
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2743
3052
|
x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
|
|
2744
3053
|
#else
|
|
2745
3054
|
x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
|
|
2746
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
3055
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2747
3056
|
}
|
|
2748
3057
|
}
|
|
2749
3058
|
|
|
@@ -2752,14 +3061,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2752
3061
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
2753
3062
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
2754
3063
|
|
|
2755
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
3064
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2756
3065
|
int * x_qs = (int *) x_tile;
|
|
2757
3066
|
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
|
2758
3067
|
#else
|
|
2759
3068
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
|
|
2760
3069
|
int * x_qs = (int *) x_tile;
|
|
2761
3070
|
float * x_df = (float *) (x_qs + txs.qs);
|
|
2762
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
3071
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2763
3072
|
|
|
2764
3073
|
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
|
|
2765
3074
|
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -2779,13 +3088,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2779
3088
|
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
|
|
2780
3089
|
const int k0 = 8 * (kqsx / 4) + kqsx % 4;
|
|
2781
3090
|
|
|
2782
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
3091
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2783
3092
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
|
|
2784
3093
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
|
|
2785
3094
|
#else
|
|
2786
3095
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
|
|
2787
3096
|
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
|
|
2788
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
3097
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2789
3098
|
}
|
|
2790
3099
|
|
|
2791
3100
|
constexpr int rows_per_warp = warp_size / 8;
|
|
@@ -2804,11 +3113,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
|
2804
3113
|
const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
|
|
2805
3114
|
| (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
|
|
2806
3115
|
|
|
2807
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
3116
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2808
3117
|
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
|
|
2809
3118
|
#else
|
|
2810
3119
|
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
|
|
2811
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
3120
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2812
3121
|
}
|
|
2813
3122
|
}
|
|
2814
3123
|
|
|
@@ -2848,9 +3157,9 @@ static __device__ __forceinline__ void mmq_write_back_mma(
|
|
|
2848
3157
|
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
2849
3158
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
2850
3159
|
|
|
2851
|
-
#if defined(AMD_MFMA_AVAILABLE)
|
|
3160
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2852
3161
|
constexpr int tileC_IJ = mmq_get_granularity_device(0);
|
|
2853
|
-
typedef tile<tileC_IJ, tileC_IJ, int> tile_C;
|
|
3162
|
+
typedef tile<tileC_IJ, tileC_IJ, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
|
2854
3163
|
constexpr int rows_per_warp = granularity;
|
|
2855
3164
|
#else
|
|
2856
3165
|
typedef tile<16, 8, int> tile_C;
|
|
@@ -2859,11 +3168,11 @@ static __device__ __forceinline__ void mmq_write_back_mma(
|
|
|
2859
3168
|
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
|
2860
3169
|
|
|
2861
3170
|
const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
|
|
2862
|
-
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
3171
|
+
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2863
3172
|
static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
|
|
2864
3173
|
#else
|
|
2865
3174
|
GGML_UNUSED(nwarps);
|
|
2866
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
3175
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
2867
3176
|
|
|
2868
3177
|
#pragma unroll
|
|
2869
3178
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
@@ -2937,8 +3246,13 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
|
|
|
2937
3246
|
template <int mmq_x, int mmq_y, bool need_check>
|
|
2938
3247
|
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
|
|
2939
3248
|
static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
|
|
3249
|
+
#ifdef BLACKWELL_MMA_AVAILABLE
|
|
3250
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>;
|
|
3251
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
|
|
3252
|
+
#else
|
|
2940
3253
|
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
|
|
2941
3254
|
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
|
3255
|
+
#endif // BLACKWELL_MMA_AVAILABLE
|
|
2942
3256
|
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
|
2943
3257
|
};
|
|
2944
3258
|
|
|
@@ -3063,25 +3377,34 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
|
|
|
3063
3377
|
int * tile_y = data_mul_mat_q + mmq_x;
|
|
3064
3378
|
int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);
|
|
3065
3379
|
|
|
3066
|
-
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
3380
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
3067
3381
|
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
|
|
3068
3382
|
constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
|
|
3069
3383
|
#else
|
|
3070
3384
|
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
|
|
3071
3385
|
constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
|
|
3072
|
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
3386
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
|
3073
3387
|
|
|
3074
|
-
|
|
3388
|
+
#if defined(BLACKWELL_MMA_AVAILABLE)
|
|
3389
|
+
// FP4 tile stores 8 blocks
|
|
3390
|
+
constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
|
|
3391
|
+
#else
|
|
3392
|
+
constexpr int ne_block = 4 * QK8_1;
|
|
3393
|
+
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
|
3394
|
+
|
|
3395
|
+
constexpr int ITER_K = get_iter_k(type);
|
|
3396
|
+
constexpr int blocks_per_iter = ITER_K / qk;
|
|
3075
3397
|
|
|
3076
3398
|
float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
|
|
3077
3399
|
|
|
3400
|
+
constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int);
|
|
3401
|
+
|
|
3078
3402
|
for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
|
|
3079
3403
|
load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
|
|
3080
|
-
|
|
3081
3404
|
{
|
|
3082
|
-
const int * by0 = y + ncols_y*(kb0*
|
|
3405
|
+
const int * by0 = y + ncols_y * (kb0 * qk / ne_block) * sz;
|
|
3083
3406
|
#pragma unroll
|
|
3084
|
-
for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
|
|
3407
|
+
for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
|
|
3085
3408
|
int l = l0 + threadIdx.y*warp_size + threadIdx.x;
|
|
3086
3409
|
|
|
3087
3410
|
tile_y[l] = by0[l];
|
|
@@ -3095,9 +3418,9 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
|
|
|
3095
3418
|
__syncthreads();
|
|
3096
3419
|
|
|
3097
3420
|
{
|
|
3098
|
-
const int * by0 = y + ncols_y*(kb0*
|
|
3421
|
+
const int * by0 = y + ncols_y * ((kb0 * qk / ne_block) * sz + sz);
|
|
3099
3422
|
#pragma unroll
|
|
3100
|
-
for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
|
|
3423
|
+
for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
|
|
3101
3424
|
int l = l0 + threadIdx.y*warp_size + threadIdx.x;
|
|
3102
3425
|
|
|
3103
3426
|
tile_y[l] = by0[l];
|
|
@@ -3229,8 +3552,10 @@ static __global__ void mul_mat_q(
|
|
|
3229
3552
|
}
|
|
3230
3553
|
#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
|
|
3231
3554
|
|
|
3555
|
+
constexpr int ITER_K = get_iter_k(type);
|
|
3556
|
+
|
|
3232
3557
|
const int64_t blocks_per_ne00 = ncols_x / qk;
|
|
3233
|
-
constexpr int blocks_per_iter =
|
|
3558
|
+
constexpr int blocks_per_iter = ITER_K / qk;
|
|
3234
3559
|
|
|
3235
3560
|
// kbc == k block continuous, current index in continuous ijk space.
|
|
3236
3561
|
int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
|
|
@@ -3291,7 +3616,7 @@ static __global__ void mul_mat_q(
|
|
|
3291
3616
|
__syncthreads();
|
|
3292
3617
|
}
|
|
3293
3618
|
|
|
3294
|
-
offset_y
|
|
3619
|
+
offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
|
|
3295
3620
|
offset_dst += it*mmq_y;
|
|
3296
3621
|
|
|
3297
3622
|
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
|
|
@@ -3358,7 +3683,7 @@ static __global__ void mul_mat_q(
|
|
|
3358
3683
|
__syncthreads();
|
|
3359
3684
|
}
|
|
3360
3685
|
|
|
3361
|
-
offset_y
|
|
3686
|
+
offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
|
|
3362
3687
|
offset_dst += it*mmq_y;
|
|
3363
3688
|
|
|
3364
3689
|
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
|
|
@@ -3381,7 +3706,9 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
|
|
3381
3706
|
const int ncols_max) {
|
|
3382
3707
|
constexpr int mmq_y = get_mmq_y_device();
|
|
3383
3708
|
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
|
3384
|
-
constexpr int
|
|
3709
|
+
constexpr int ITER_K = get_iter_k(type);
|
|
3710
|
+
|
|
3711
|
+
constexpr int blocks_per_iter = ITER_K / qk;
|
|
3385
3712
|
const int64_t blocks_per_ne00 = ncols_x / qk;
|
|
3386
3713
|
|
|
3387
3714
|
constexpr int nwarps = mmq_get_nwarps_device();
|
|
@@ -3494,7 +3821,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
|
|
3494
3821
|
const int col_diff = col_high - col_low;
|
|
3495
3822
|
|
|
3496
3823
|
for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {
|
|
3497
|
-
ids_dst_shared[j] = ids_dst[col_low + j];
|
|
3824
|
+
ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
|
|
3498
3825
|
}
|
|
3499
3826
|
__syncthreads();
|
|
3500
3827
|
|
|
@@ -3538,8 +3865,8 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
|
|
|
3538
3865
|
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
|
|
3539
3866
|
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
|
|
3540
3867
|
const size_t nbs_ids = mmq_x*sizeof(int);
|
|
3541
|
-
const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
|
|
3542
|
-
const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
|
|
3868
|
+
const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
|
|
3869
|
+
const size_t nbs_y = mmq_x * (sizeof(block_q8_1_mmq));
|
|
3543
3870
|
return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
|
|
3544
3871
|
}
|
|
3545
3872
|
|
|
@@ -3755,4 +4082,4 @@ void ggml_cuda_op_mul_mat_q(
|
|
|
3755
4082
|
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
|
3756
4083
|
const int64_t src1_padded_row_size, cudaStream_t stream);
|
|
3757
4084
|
|
|
3758
|
-
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11);
|
|
4085
|
+
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);
|