whispercpp 1.3.4 → 1.3.5
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +47 -23
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/cli/cli.cpp +121 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +10 -11
- data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
- data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
- data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
- data/ext/sources/examples/talk-llama/llama-context.h +57 -9
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
- data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
- data/ext/sources/examples/talk-llama/llama-model.h +44 -3
- data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
- data/ext/sources/examples/talk-llama/llama.cpp +729 -2
- data/ext/sources/examples/talk-llama/llama.h +152 -14
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +82 -54
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +4 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +190 -12
- data/ext/sources/ggml/src/CMakeLists.txt +82 -11
- data/ext/sources/ggml/src/ggml-alloc.c +124 -41
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
- data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
- data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
- data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
- data/ext/sources/ggml/src/ggml-impl.h +67 -6
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +425 -33
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +101 -35
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +119 -2
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +7 -0
- data/whispercpp.gemspec +1 -1
- metadata +287 -34
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -18,6 +18,10 @@
|
|
|
18
18
|
|
|
19
19
|
#include "common.cuh"
|
|
20
20
|
|
|
21
|
+
// On Volta each warp is doing 4 8x8 mma operations in parallel.
|
|
22
|
+
// The basic memory layout for a 32x8 output tile is to stack 4 input tiles in I direction and to mirror the B tile.
|
|
23
|
+
// However, the i indices in this file are by default permuted to simplify the index calculations.
|
|
24
|
+
// #define GGML_CUDA_MMA_NO_VOLTA_PERM
|
|
21
25
|
|
|
22
26
|
#if CUDART_VERSION >= 11080
|
|
23
27
|
|
|
@@ -64,15 +68,59 @@ static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
|
|
|
64
68
|
|
|
65
69
|
namespace ggml_cuda_mma {
|
|
66
70
|
|
|
71
|
+
// Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel,
|
|
72
|
+
// effectively the warp is being split into subgroups of threads that each perform a single mma instruction.
|
|
73
|
+
// In those cases the data can be split in different ways across the warp.
|
|
74
|
+
enum data_layout {
|
|
75
|
+
// By default the data uses the I direction as its major dimension and the J direction as its minor dimension.
|
|
76
|
+
// For the A/C matrices this means I major == row major, J major == column major.
|
|
77
|
+
// For the B matrix this means I major == column major, J major == row major.
|
|
78
|
+
// MIRRORED == Each data value is held exactly once per thread subgroup.
|
|
79
|
+
DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
|
|
80
|
+
DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
|
|
81
|
+
DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3.
|
|
82
|
+
DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
|
|
83
|
+
};
|
|
84
|
+
// Implemented mma combinations are:
|
|
85
|
+
// - (I_MAJOR, I_MAJOR) -> I_MAJOR
|
|
86
|
+
// - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
|
|
87
|
+
// - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
|
|
88
|
+
|
|
89
|
+
static constexpr bool is_i_major(const data_layout dl) {
|
|
90
|
+
return dl == DATA_LAYOUT_I_MAJOR ||
|
|
91
|
+
dl == DATA_LAYOUT_I_MAJOR_MIRRORED;
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
static constexpr __device__ data_layout get_input_data_layout() {
|
|
95
|
+
#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
96
|
+
return DATA_LAYOUT_I_MAJOR_MIRRORED;
|
|
97
|
+
#else
|
|
98
|
+
return DATA_LAYOUT_I_MAJOR;
|
|
99
|
+
#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
|
|
103
|
+
struct tile {};
|
|
104
|
+
|
|
67
105
|
template <int I_, int J_, typename T>
|
|
68
|
-
struct tile {
|
|
69
|
-
static constexpr int
|
|
70
|
-
static constexpr int
|
|
106
|
+
struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR> {
|
|
107
|
+
static constexpr int I = I_;
|
|
108
|
+
static constexpr int J = J_;
|
|
109
|
+
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
|
|
71
110
|
|
|
72
|
-
#if defined(
|
|
111
|
+
#if defined(AMD_MFMA_AVAILABLE)
|
|
73
112
|
static constexpr int ne = I * J / 64;
|
|
74
113
|
T x[ne] = {0};
|
|
75
114
|
|
|
115
|
+
static constexpr __device__ bool supported() {
|
|
116
|
+
if (I == 64 && J == 2) return true;
|
|
117
|
+
if (I == 16 && J == 8) return true;
|
|
118
|
+
if (I == 32 && J == 4) return true;
|
|
119
|
+
if (I == 16 && J == 16) return true;
|
|
120
|
+
if (I == 32 && J == 32) return true;
|
|
121
|
+
return false;
|
|
122
|
+
}
|
|
123
|
+
|
|
76
124
|
static __device__ __forceinline__ int get_i(const int l) {
|
|
77
125
|
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
|
|
78
126
|
return threadIdx.x % 16;
|
|
@@ -81,11 +129,12 @@ namespace ggml_cuda_mma {
|
|
|
81
129
|
} else if constexpr (I == 32 && J == 4) {
|
|
82
130
|
return threadIdx.x % 32;
|
|
83
131
|
} else if constexpr (I == 16 && J == 16) {
|
|
84
|
-
return
|
|
132
|
+
return threadIdx.x % 16;
|
|
85
133
|
} else if constexpr (I == 32 && J == 32) {
|
|
86
|
-
return
|
|
134
|
+
return threadIdx.x % 32;
|
|
87
135
|
} else {
|
|
88
|
-
|
|
136
|
+
NO_DEVICE_CODE;
|
|
137
|
+
return -1;
|
|
89
138
|
}
|
|
90
139
|
}
|
|
91
140
|
|
|
@@ -97,26 +146,109 @@ namespace ggml_cuda_mma {
|
|
|
97
146
|
} else if constexpr (I == 32 && J == 4) {
|
|
98
147
|
return 2 * (threadIdx.x / 32) + l;
|
|
99
148
|
} else if constexpr (I == 16 && J == 16) {
|
|
100
|
-
return threadIdx.x
|
|
149
|
+
return 4 * (threadIdx.x / 16) + l;
|
|
101
150
|
} else if constexpr (I == 32 && J == 32) {
|
|
102
|
-
return threadIdx.x %
|
|
151
|
+
return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
|
|
103
152
|
} else {
|
|
104
|
-
|
|
153
|
+
NO_DEVICE_CODE;
|
|
154
|
+
return -1;
|
|
105
155
|
}
|
|
106
156
|
}
|
|
157
|
+
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
158
|
+
static constexpr int ne = I * J / 32;
|
|
159
|
+
T x[ne] = {0};
|
|
160
|
+
|
|
161
|
+
static constexpr __device__ bool supported() {
|
|
162
|
+
if (I == 32 && J == 8) return true;
|
|
163
|
+
return false;
|
|
164
|
+
}
|
|
165
|
+
|
|
166
|
+
static __device__ __forceinline__ int get_i(const int l) {
|
|
167
|
+
if constexpr (I == 32 && J == 8) {
|
|
168
|
+
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
|
|
169
|
+
return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2);
|
|
107
170
|
#else
|
|
171
|
+
return (l & 2) + (threadIdx.x & ~2);
|
|
172
|
+
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
|
|
173
|
+
} else {
|
|
174
|
+
NO_DEVICE_CODE;
|
|
175
|
+
return -1;
|
|
176
|
+
}
|
|
177
|
+
}
|
|
178
|
+
|
|
179
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
180
|
+
if constexpr (I == 32 && J == 8) {
|
|
181
|
+
return (threadIdx.x & 2) + (l & (4 + 1));
|
|
182
|
+
} else {
|
|
183
|
+
NO_DEVICE_CODE;
|
|
184
|
+
return -1;
|
|
185
|
+
}
|
|
186
|
+
}
|
|
187
|
+
#elif defined(AMD_WMMA_AVAILABLE)
|
|
108
188
|
static constexpr int ne = I * J / 32;
|
|
109
189
|
T x[ne] = {0};
|
|
110
190
|
|
|
191
|
+
static constexpr __device__ bool supported() {
|
|
192
|
+
if (I == 16 && J == 16) return true;
|
|
193
|
+
if (I == 16 && J == 8) return true;
|
|
194
|
+
if (I == 16 && J == 4) return true;
|
|
195
|
+
return false;
|
|
196
|
+
}
|
|
197
|
+
|
|
111
198
|
static __device__ __forceinline__ int get_i(const int l) {
|
|
112
|
-
if constexpr (
|
|
199
|
+
if constexpr (supported()) {
|
|
200
|
+
return threadIdx.x % 16;
|
|
201
|
+
} else {
|
|
202
|
+
NO_DEVICE_CODE;
|
|
203
|
+
return -1;
|
|
204
|
+
}
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
208
|
+
if constexpr (I == 16 && J == 16) {
|
|
209
|
+
// matrix C
|
|
210
|
+
#if defined(RDNA3)
|
|
211
|
+
return 2 * l + (threadIdx.x / 16);
|
|
212
|
+
#else
|
|
213
|
+
return ne * (threadIdx.x / 16) + l;
|
|
214
|
+
#endif // defined(RDNA3)
|
|
215
|
+
} else if constexpr (I == 16 && J == 8) {
|
|
216
|
+
// mmq input for RDNA4
|
|
217
|
+
return ne * (threadIdx.x / 16) + l;
|
|
218
|
+
} else if constexpr (I == 16 && J == 4) {
|
|
219
|
+
return ne * (threadIdx.x / 16) + l;
|
|
220
|
+
} else {
|
|
221
|
+
NO_DEVICE_CODE;
|
|
222
|
+
return -1;
|
|
223
|
+
}
|
|
224
|
+
}
|
|
225
|
+
#else
|
|
226
|
+
static constexpr int ne = I * J / 32;
|
|
227
|
+
T x[ne] = {0};
|
|
228
|
+
|
|
229
|
+
static constexpr __device__ bool supported() {
|
|
230
|
+
if (I == 8 && J == 4) return true;
|
|
231
|
+
if (I == 8 && J == 8) return true;
|
|
232
|
+
if (I == 16 && J == 8) return true;
|
|
233
|
+
if (I == 16 && J == 16) return true;
|
|
234
|
+
if (I == 32 && J == 8) return true;
|
|
235
|
+
return false;
|
|
236
|
+
}
|
|
237
|
+
|
|
238
|
+
static __device__ __forceinline__ int get_i(const int l) {
|
|
239
|
+
if constexpr (I == 8 && J == 4) {
|
|
240
|
+
return threadIdx.x / 4;
|
|
241
|
+
} else if constexpr (I == 8 && J == 8) {
|
|
113
242
|
return threadIdx.x / 4;
|
|
114
243
|
} else if constexpr (I == 16 && J == 8) {
|
|
115
|
-
return (l / 2) * 8 + threadIdx.x / 4;
|
|
244
|
+
return ((l / 2) * 8) + (threadIdx.x / 4);
|
|
116
245
|
} else if constexpr (I == 16 && J == 16) {
|
|
117
|
-
return ((l / 2) % 2) * 8 + threadIdx.x / 4;
|
|
246
|
+
return (((l / 2) % 2) * 8) + (threadIdx.x / 4);
|
|
247
|
+
} else if constexpr (I == 32 && J == 8) {
|
|
248
|
+
return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
|
|
118
249
|
} else {
|
|
119
|
-
|
|
250
|
+
NO_DEVICE_CODE;
|
|
251
|
+
return -1;
|
|
120
252
|
}
|
|
121
253
|
}
|
|
122
254
|
|
|
@@ -124,82 +256,354 @@ namespace ggml_cuda_mma {
|
|
|
124
256
|
if constexpr (I == 8 && J == 4) {
|
|
125
257
|
return threadIdx.x % 4;
|
|
126
258
|
} else if constexpr (I == 8 && J == 8) {
|
|
127
|
-
return
|
|
259
|
+
return (l * 4) + (threadIdx.x % 4);
|
|
128
260
|
} else if constexpr (I == 16 && J == 8) {
|
|
129
|
-
return
|
|
261
|
+
return ((threadIdx.x % 4) * 2) + (l % 2);
|
|
130
262
|
} else if constexpr (I == 16 && J == 16) {
|
|
131
|
-
return
|
|
263
|
+
return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2);
|
|
264
|
+
} else if constexpr (I == 32 && J == 8) {
|
|
265
|
+
return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
|
|
132
266
|
} else {
|
|
133
|
-
|
|
267
|
+
NO_DEVICE_CODE;
|
|
268
|
+
return -1;
|
|
134
269
|
}
|
|
135
270
|
}
|
|
136
271
|
#endif // defined(GGML_USE_HIP)
|
|
137
272
|
};
|
|
138
273
|
|
|
139
274
|
template <int I_, int J_>
|
|
140
|
-
struct tile<I_, J_, half2> {
|
|
141
|
-
static constexpr int
|
|
142
|
-
static constexpr int
|
|
275
|
+
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR> {
|
|
276
|
+
static constexpr int I = I_;
|
|
277
|
+
static constexpr int J = J_;
|
|
278
|
+
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
|
|
279
|
+
|
|
280
|
+
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
281
|
+
static constexpr int ne = I * J / WARP_SIZE;
|
|
282
|
+
half2 x[ne] = {{0.0f, 0.0f}};
|
|
283
|
+
|
|
284
|
+
static constexpr __device__ bool supported() {
|
|
285
|
+
if (I == 32 && J == 4) return true;
|
|
286
|
+
return false;
|
|
287
|
+
}
|
|
288
|
+
|
|
289
|
+
static __device__ __forceinline__ int get_i(const int l) {
|
|
290
|
+
if constexpr (I == 32 && J == 4) {
|
|
291
|
+
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
|
|
292
|
+
return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
|
|
293
|
+
#else
|
|
294
|
+
return threadIdx.x;
|
|
295
|
+
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
|
|
296
|
+
} else {
|
|
297
|
+
NO_DEVICE_CODE;
|
|
298
|
+
return -1;
|
|
299
|
+
}
|
|
300
|
+
}
|
|
301
|
+
|
|
302
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
303
|
+
if constexpr (I == 32 && J == 4) {
|
|
304
|
+
return l;
|
|
305
|
+
} else {
|
|
306
|
+
NO_DEVICE_CODE;
|
|
307
|
+
return -1;
|
|
308
|
+
}
|
|
309
|
+
}
|
|
310
|
+
#elif defined(AMD_WMMA_AVAILABLE)
|
|
311
|
+
static constexpr int ne = I * J / 32;
|
|
312
|
+
half2 x[ne] = {{0.0f, 0.0f}};
|
|
313
|
+
|
|
314
|
+
static constexpr __device__ bool supported() {
|
|
315
|
+
if (I == 16 && J == 8) return true;
|
|
316
|
+
return false;
|
|
317
|
+
}
|
|
318
|
+
|
|
319
|
+
static __device__ __forceinline__ int get_i(const int l) {
|
|
320
|
+
if constexpr (I == 16 && J == 8) {
|
|
321
|
+
return threadIdx.x % 16;
|
|
322
|
+
} else {
|
|
323
|
+
NO_DEVICE_CODE;
|
|
324
|
+
return -1;
|
|
325
|
+
}
|
|
326
|
+
}
|
|
327
|
+
|
|
328
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
329
|
+
if constexpr (I == 16 && J == 8) {
|
|
330
|
+
return 4 * (threadIdx.x / 16) + l;
|
|
331
|
+
} else {
|
|
332
|
+
NO_DEVICE_CODE;
|
|
333
|
+
return -1;
|
|
334
|
+
}
|
|
335
|
+
}
|
|
336
|
+
#else
|
|
143
337
|
static constexpr int ne = I * J / WARP_SIZE;
|
|
144
338
|
half2 x[ne] = {{0.0f, 0.0f}};
|
|
145
339
|
|
|
340
|
+
static constexpr __device__ bool supported() {
|
|
341
|
+
if (I == 8 && J == 4) return true;
|
|
342
|
+
if (I == 8 && J == 8) return true;
|
|
343
|
+
if (I == 16 && J == 8) return true;
|
|
344
|
+
if (I == 16 && J == 16) return true;
|
|
345
|
+
if (I == 32 && J == 8) return true;
|
|
346
|
+
return false;
|
|
347
|
+
}
|
|
348
|
+
|
|
146
349
|
static __device__ __forceinline__ int get_i(const int l) {
|
|
147
350
|
if constexpr (I == 8 && J == 8) {
|
|
148
351
|
return threadIdx.x / 4;
|
|
149
352
|
} else if constexpr (I == 16 && J == 4) {
|
|
150
|
-
return l * 8 + threadIdx.x / 4;
|
|
353
|
+
return (l * 8) + (threadIdx.x / 4);
|
|
151
354
|
} else if constexpr (I == 16 && J == 8) {
|
|
152
|
-
return (l % 2) * 8 + threadIdx.x / 4;
|
|
355
|
+
return ((l % 2) * 8) + (threadIdx.x / 4);
|
|
356
|
+
} else if constexpr (I == 32 && J == 8) {
|
|
357
|
+
return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4);
|
|
153
358
|
} else {
|
|
154
|
-
|
|
359
|
+
NO_DEVICE_CODE;
|
|
360
|
+
return -1;
|
|
155
361
|
}
|
|
156
362
|
}
|
|
157
363
|
|
|
158
364
|
static __device__ __forceinline__ int get_j(const int l) {
|
|
159
365
|
if constexpr (I == 8 && J == 8) {
|
|
160
|
-
return l * 4 + threadIdx.x % 4;
|
|
366
|
+
return (l * 4) + (threadIdx.x % 4);
|
|
161
367
|
} else if constexpr (I == 16 && J == 4) {
|
|
162
368
|
return threadIdx.x % 4;
|
|
163
369
|
} else if constexpr (I == 16 && J == 8) {
|
|
164
|
-
return (l / 2) * 4 + threadIdx.x % 4;
|
|
370
|
+
return ((l / 2) * 4) + (threadIdx.x % 4);
|
|
371
|
+
} else if constexpr (I == 32 && J == 8) {
|
|
372
|
+
return ((l & 2) * 2) + (threadIdx.x % 4);
|
|
165
373
|
} else {
|
|
166
|
-
|
|
374
|
+
NO_DEVICE_CODE;
|
|
375
|
+
return -1;
|
|
167
376
|
}
|
|
168
377
|
}
|
|
378
|
+
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
169
379
|
};
|
|
170
380
|
|
|
171
381
|
template <int I_, int J_>
|
|
172
|
-
struct tile<I_, J_, nv_bfloat162> {
|
|
173
|
-
static constexpr int
|
|
174
|
-
static constexpr int
|
|
382
|
+
struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR> {
|
|
383
|
+
static constexpr int I = I_;
|
|
384
|
+
static constexpr int J = J_;
|
|
385
|
+
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
|
|
386
|
+
|
|
387
|
+
#if defined(AMD_WMMA_AVAILABLE)
|
|
388
|
+
static constexpr int ne = I * J / 32;
|
|
389
|
+
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
|
390
|
+
|
|
391
|
+
static constexpr __device__ bool supported() {
|
|
392
|
+
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
|
|
393
|
+
}
|
|
394
|
+
|
|
395
|
+
static __device__ __forceinline__ int get_i(const int l) {
|
|
396
|
+
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
|
|
397
|
+
}
|
|
398
|
+
|
|
399
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
400
|
+
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
|
|
401
|
+
}
|
|
402
|
+
#else
|
|
175
403
|
static constexpr int ne = I * J / WARP_SIZE;
|
|
176
404
|
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
|
177
405
|
|
|
406
|
+
static constexpr __device__ bool supported() {
|
|
407
|
+
if (I == 8 && J == 8) return true;
|
|
408
|
+
if (I == 16 && J == 4) return true;
|
|
409
|
+
if (I == 16 && J == 8) return true;
|
|
410
|
+
return false;
|
|
411
|
+
}
|
|
412
|
+
|
|
178
413
|
static __device__ __forceinline__ int get_i(const int l) {
|
|
179
414
|
if constexpr (I == 8 && J == 8) {
|
|
180
415
|
return threadIdx.x / 4;
|
|
181
416
|
} else if constexpr (I == 16 && J == 4) {
|
|
182
|
-
return l * 8 + threadIdx.x / 4;
|
|
417
|
+
return (l * 8) + (threadIdx.x / 4);
|
|
183
418
|
} else if constexpr (I == 16 && J == 8) {
|
|
184
|
-
return (l % 2) * 8 + threadIdx.x / 4;
|
|
419
|
+
return ((l % 2) * 8) + (threadIdx.x / 4);
|
|
185
420
|
} else {
|
|
186
|
-
|
|
421
|
+
NO_DEVICE_CODE;
|
|
422
|
+
return -1;
|
|
187
423
|
}
|
|
188
424
|
}
|
|
189
425
|
|
|
190
426
|
static __device__ __forceinline__ int get_j(const int l) {
|
|
191
427
|
if constexpr (I == 8 && J == 8) {
|
|
192
|
-
return l * 4 + threadIdx.x % 4;
|
|
428
|
+
return (l * 4) + (threadIdx.x % 4);
|
|
193
429
|
} else if constexpr (I == 16 && J == 4) {
|
|
194
430
|
return threadIdx.x % 4;
|
|
195
431
|
} else if constexpr (I == 16 && J == 8) {
|
|
196
|
-
return (l / 2) * 4 + threadIdx.x % 4;
|
|
432
|
+
return ((l / 2) * 4) + (threadIdx.x % 4);
|
|
433
|
+
} else {
|
|
434
|
+
NO_DEVICE_CODE;
|
|
435
|
+
return -1;
|
|
436
|
+
}
|
|
437
|
+
}
|
|
438
|
+
#endif // defined(AMD_WMMA_AVAILABLE)
|
|
439
|
+
};
|
|
440
|
+
|
|
441
|
+
template <int I_, int J_, typename T>
|
|
442
|
+
struct tile<I_, J_, T, DATA_LAYOUT_J_MAJOR> {
|
|
443
|
+
static constexpr int I = I_;
|
|
444
|
+
static constexpr int J = J_;
|
|
445
|
+
static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR;
|
|
446
|
+
|
|
447
|
+
static constexpr int ne = tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::ne;
|
|
448
|
+
T x[ne] = {0};
|
|
449
|
+
|
|
450
|
+
static constexpr __device__ bool supported() {
|
|
451
|
+
return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::supported();
|
|
452
|
+
}
|
|
453
|
+
|
|
454
|
+
static __device__ __forceinline__ int get_i(const int l) {
|
|
455
|
+
return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_j(l);
|
|
456
|
+
}
|
|
457
|
+
|
|
458
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
459
|
+
return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_i(l);
|
|
460
|
+
}
|
|
461
|
+
};
|
|
462
|
+
|
|
463
|
+
template <int I_, int J_, typename T>
|
|
464
|
+
struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_MIRRORED> {
|
|
465
|
+
static constexpr int I = I_;
|
|
466
|
+
static constexpr int J = J_;
|
|
467
|
+
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
|
|
468
|
+
|
|
469
|
+
// RDNA3
|
|
470
|
+
static constexpr int ne = I * J / 32 * 2;
|
|
471
|
+
|
|
472
|
+
T x[ne] = {0};
|
|
473
|
+
|
|
474
|
+
static constexpr __device__ bool supported() {
|
|
475
|
+
if (I == 16 && J == 16) return true;
|
|
476
|
+
if (I == 16 && J == 8) return true;
|
|
477
|
+
if (I == 16 && J == 4) return true;
|
|
478
|
+
return false;
|
|
479
|
+
}
|
|
480
|
+
|
|
481
|
+
static __device__ __forceinline__ int get_i(const int /*l*/) {
|
|
482
|
+
if constexpr (supported()) {
|
|
483
|
+
return threadIdx.x % 16;
|
|
484
|
+
} else {
|
|
485
|
+
NO_DEVICE_CODE;
|
|
486
|
+
return -1;
|
|
487
|
+
}
|
|
488
|
+
}
|
|
489
|
+
|
|
490
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
491
|
+
if constexpr (supported()) {
|
|
492
|
+
return l;
|
|
493
|
+
} else {
|
|
494
|
+
NO_DEVICE_CODE;
|
|
495
|
+
return -1;
|
|
496
|
+
}
|
|
497
|
+
}
|
|
498
|
+
};
|
|
499
|
+
|
|
500
|
+
template <int I_, int J_>
|
|
501
|
+
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
|
|
502
|
+
static constexpr int I = I_;
|
|
503
|
+
static constexpr int J = J_;
|
|
504
|
+
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
|
|
505
|
+
#if defined(RDNA3)
|
|
506
|
+
static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
|
|
507
|
+
|
|
508
|
+
half2 x[ne] = {{0.0f, 0.0f}};
|
|
509
|
+
|
|
510
|
+
static constexpr __device__ bool supported() {
|
|
511
|
+
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
|
|
512
|
+
}
|
|
513
|
+
|
|
514
|
+
static __device__ __forceinline__ int get_i(const int l) {
|
|
515
|
+
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
|
|
516
|
+
}
|
|
517
|
+
|
|
518
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
519
|
+
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
|
|
520
|
+
}
|
|
521
|
+
#else // Volta
|
|
522
|
+
static constexpr int ne = I * J / (WARP_SIZE/4);
|
|
523
|
+
|
|
524
|
+
half2 x[ne] = {{0.0f, 0.0f}};
|
|
525
|
+
|
|
526
|
+
static constexpr __device__ bool supported() {
|
|
527
|
+
if (I == 8 && J == 4) return true;
|
|
528
|
+
return false;
|
|
529
|
+
}
|
|
530
|
+
|
|
531
|
+
static __device__ __forceinline__ int get_i(const int /*l*/) {
|
|
532
|
+
if constexpr (I == 8 && J == 4) {
|
|
533
|
+
return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
|
|
197
534
|
} else {
|
|
198
|
-
|
|
535
|
+
NO_DEVICE_CODE;
|
|
536
|
+
return -1;
|
|
199
537
|
}
|
|
200
538
|
}
|
|
539
|
+
|
|
540
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
541
|
+
if constexpr (I == 8 && J == 4) {
|
|
542
|
+
return l;
|
|
543
|
+
} else {
|
|
544
|
+
NO_DEVICE_CODE;
|
|
545
|
+
return -1;
|
|
546
|
+
}
|
|
547
|
+
}
|
|
548
|
+
#endif // defined(RDNA3)
|
|
549
|
+
};
|
|
550
|
+
|
|
551
|
+
template <int I_, int J_>
|
|
552
|
+
struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR_MIRRORED> {
|
|
553
|
+
static constexpr int I = I_;
|
|
554
|
+
static constexpr int J = J_;
|
|
555
|
+
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
|
|
556
|
+
static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
|
|
557
|
+
|
|
558
|
+
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
|
559
|
+
|
|
560
|
+
static constexpr __device__ bool supported() {
|
|
561
|
+
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
|
|
562
|
+
}
|
|
563
|
+
|
|
564
|
+
static __device__ __forceinline__ int get_i(const int l) {
|
|
565
|
+
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
|
|
566
|
+
}
|
|
567
|
+
|
|
568
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
569
|
+
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
|
|
570
|
+
}
|
|
201
571
|
};
|
|
202
572
|
|
|
573
|
+
template <int I_, int J_>
|
|
574
|
+
struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
|
|
575
|
+
static constexpr int I = I_;
|
|
576
|
+
static constexpr int J = J_;
|
|
577
|
+
static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
|
|
578
|
+
static constexpr int ne = I * J / (WARP_SIZE/4);
|
|
579
|
+
|
|
580
|
+
half2 x[ne] = {{0.0f, 0.0f}};
|
|
581
|
+
|
|
582
|
+
static constexpr __device__ bool supported() {
|
|
583
|
+
if (I == 8 && J == 4) return true;
|
|
584
|
+
return false;
|
|
585
|
+
}
|
|
586
|
+
|
|
587
|
+
static __device__ __forceinline__ int get_i(const int l) {
|
|
588
|
+
if constexpr (I == 8 && J == 4) {
|
|
589
|
+
return ((l / 2) * 4) + (threadIdx.x % 4);
|
|
590
|
+
} else {
|
|
591
|
+
NO_DEVICE_CODE;
|
|
592
|
+
return -1;
|
|
593
|
+
}
|
|
594
|
+
}
|
|
595
|
+
|
|
596
|
+
static __device__ __forceinline__ int get_j(const int l) {
|
|
597
|
+
if constexpr (I == 8 && J == 4) {
|
|
598
|
+
return ((threadIdx.x / 16) * 2) + (l % 2);
|
|
599
|
+
} else {
|
|
600
|
+
NO_DEVICE_CODE;
|
|
601
|
+
return -1;
|
|
602
|
+
}
|
|
603
|
+
}
|
|
604
|
+
};
|
|
605
|
+
|
|
606
|
+
#if defined(TURING_MMA_AVAILABLE)
|
|
203
607
|
template <int I, int J>
|
|
204
608
|
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
|
|
205
609
|
tile<I, J/2, half2> ret;
|
|
@@ -217,9 +621,26 @@ namespace ggml_cuda_mma {
|
|
|
217
621
|
|
|
218
622
|
return ret;
|
|
219
623
|
}
|
|
624
|
+
#else // Volta
|
|
625
|
+
template <int I, int J>
|
|
626
|
+
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
|
|
627
|
+
tile<I, J/2, half2> ret;
|
|
628
|
+
#pragma unroll
|
|
629
|
+
for (int l0 = 0; l0 < tile_float.ne; l0 += 4) {
|
|
630
|
+
ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
|
|
631
|
+
ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]);
|
|
632
|
+
|
|
633
|
+
// On Volta FP16 and FP32 tiles have a different memory layout,
|
|
634
|
+
// for the conversion threads with an offset of 2 need to exchange half their values:
|
|
635
|
+
ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync(
|
|
636
|
+
0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE);
|
|
637
|
+
}
|
|
638
|
+
return ret;
|
|
639
|
+
}
|
|
640
|
+
#endif // defined(TURING_MMA_AVAILABLE)
|
|
220
641
|
|
|
221
|
-
template <int I, int J, typename T>
|
|
222
|
-
static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
|
|
642
|
+
template <int I, int J, typename T, data_layout dl>
|
|
643
|
+
static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
|
|
223
644
|
#if defined(AMD_MFMA_AVAILABLE)
|
|
224
645
|
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
|
|
225
646
|
#pragma unroll
|
|
@@ -227,9 +648,28 @@ namespace ggml_cuda_mma {
|
|
|
227
648
|
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
|
|
228
649
|
}
|
|
229
650
|
} else {
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
651
|
+
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
|
|
652
|
+
}
|
|
653
|
+
#elif defined(AMD_WMMA_AVAILABLE)
|
|
654
|
+
// All WMMA layouts have contiguous data when i-major.
|
|
655
|
+
if constexpr (is_i_major(dl)) {
|
|
656
|
+
// the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes()
|
|
657
|
+
constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();
|
|
658
|
+
if constexpr (sizeof(t.x) > aligned_copy_bytes) {
|
|
659
|
+
static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size");
|
|
660
|
+
constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;
|
|
661
|
+
#pragma unroll
|
|
662
|
+
for (int i = 0; i < aligned_copy_count; ++i) {
|
|
663
|
+
ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));
|
|
664
|
+
}
|
|
665
|
+
} else {
|
|
666
|
+
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
|
|
667
|
+
}
|
|
668
|
+
} else {
|
|
669
|
+
#pragma unroll
|
|
670
|
+
for (int l = 0; l < t.ne; ++l) {
|
|
671
|
+
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
|
|
672
|
+
}
|
|
233
673
|
}
|
|
234
674
|
#else
|
|
235
675
|
#pragma unroll
|
|
@@ -263,25 +703,63 @@ namespace ggml_cuda_mma {
|
|
|
263
703
|
: "=r"(xi[0]), "=r"(xi[1])
|
|
264
704
|
: "l"(xs));
|
|
265
705
|
#else
|
|
266
|
-
|
|
267
|
-
|
|
706
|
+
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
707
|
+
GGML_UNUSED_VARS(t, xs0, stride);
|
|
708
|
+
NO_DEVICE_CODE;
|
|
709
|
+
#else
|
|
710
|
+
load_generic(t, xs0, stride);
|
|
711
|
+
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
268
712
|
#endif // TURING_MMA_AVAILABLE
|
|
269
713
|
}
|
|
270
714
|
|
|
271
|
-
template <typename T>
|
|
715
|
+
template <typename T, data_layout dl>
|
|
272
716
|
static __device__ __forceinline__ void load_ldmatrix(
|
|
273
|
-
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
|
|
717
|
+
tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {
|
|
274
718
|
#if defined(TURING_MMA_AVAILABLE)
|
|
275
719
|
int * xi = (int * ) t.x;
|
|
276
720
|
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
|
|
277
721
|
asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
|
|
278
722
|
: "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
|
|
279
723
|
: "l"(xs));
|
|
724
|
+
#else
|
|
725
|
+
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
726
|
+
#if 1
|
|
727
|
+
// TODO: more generic handling
|
|
728
|
+
static_assert(sizeof(T) == 4, "bad type size");
|
|
729
|
+
ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
|
|
730
|
+
ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
|
|
731
|
+
#else
|
|
732
|
+
load_generic(t, xs0, stride);
|
|
733
|
+
#endif // 1
|
|
280
734
|
#else
|
|
281
735
|
load_generic(t, xs0, stride);
|
|
736
|
+
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
282
737
|
#endif // TURING_MMA_AVAILABLE
|
|
283
738
|
}
|
|
284
739
|
|
|
740
|
+
// Loads an 8x4 half2 tile in the I-major mirrored layout.
// All 4 half2 values of the thread are copied in one go, starting at the
// beginning of row t.get_i(0) of xs0 (stride = row pitch in half2 elements).
static __device__ __forceinline__ void load_ldmatrix(
        tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
    const half2 * row_start = xs0 + t.get_i(0)*stride;
    ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, row_start);
}
|
|
744
|
+
|
|
745
|
+
// Loads an 8x4 half2 tile in the J-major mirrored layout.
// Elements are copied two at a time; each pair starts at the position given by
// get_i / get_j of its first element (stride = row pitch in half2 elements).
static __device__ __forceinline__ void load_ldmatrix(
        tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
#pragma unroll
    for (int l0 = 0; l0 < t.ne; l0 += 2) {
        const half2 * pair_src = xs0 + t.get_i(l0)*stride + t.get_j(l0);
        ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, pair_src);
    }
}
|
|
752
|
+
|
|
753
|
+
// Loads a 32x4 half2 tile. Only implemented when compiling for Volta, where the
// thread copies 4 half2 values from the start of row t.get_i(0); every other
// target hits NO_DEVICE_CODE.
static __device__ __forceinline__ void load_ldmatrix(
        tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    const half2 * row_start = xs0 + t.get_i(0)*stride;
    ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, row_start);
#else
    GGML_UNUSED_VARS(t, xs0, stride);
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
}
|
|
762
|
+
|
|
285
763
|
template <typename T>
|
|
286
764
|
static __device__ __forceinline__ void load_ldmatrix_trans(
|
|
287
765
|
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
|
|
@@ -406,8 +884,9 @@ namespace ggml_cuda_mma {
|
|
|
406
884
|
#endif // TURING_MMA_AVAILABLE
|
|
407
885
|
}
|
|
408
886
|
|
|
887
|
+
template <data_layout dl_ab, data_layout dl_d>
|
|
409
888
|
static __device__ __forceinline__ void mma(
|
|
410
|
-
tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
|
|
889
|
+
tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) {
|
|
411
890
|
#ifdef AMPERE_MMA_AVAILABLE
|
|
412
891
|
const int * Axi = (const int *) A.x;
|
|
413
892
|
const int * Bxi = (const int *) B.x;
|
|
@@ -421,6 +900,27 @@ namespace ggml_cuda_mma {
|
|
|
421
900
|
#endif // AMPERE_MMA_AVAILABLE
|
|
422
901
|
}
|
|
423
902
|
|
|
903
|
+
// Block-scaled matrix multiply-accumulate into D (in place).
// With BLACKWELL_MMA_AVAILABLE this issues one mma.sync ... kind::mxf4.block_scale
// m16n8k64 instruction: A contributes 4 registers, B 2 registers, D 4 float
// registers, and a_scale / b_scale are passed as the scale operands.
// NOTE(review): the fallback branch only marks the arguments as unused and does
// not emit NO_DEVICE_CODE like the other mma overloads - confirm this is intended.
static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> &     D,
                                                        const tile<16, 8, int> & A,
                                                        const tile<8, 8, int> &  B,
                                                        uint32_t                 a_scale,
                                                        uint32_t                 b_scale) {
#ifdef BLACKWELL_MMA_AVAILABLE
    const int * a_regs = (const int *) A.x;
    const int * b_regs = (const int *) B.x;
    float *     d_regs = (float *) D.x;

    asm volatile(
        "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
        "%10, {0, 0}, %11, {0, 0};"
        : "+f"(d_regs[0]), "+f"(d_regs[1]), "+f"(d_regs[2]), "+f"(d_regs[3])
        : "r"(a_regs[0]), "r"(a_regs[1]), "r"(a_regs[2]), "r"(a_regs[3]), "r"(b_regs[0]), "r"(b_regs[1]), "r"(a_scale), "r"(b_scale));
#else
    GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);
#endif // BLACKWELL_MMA_AVAILABLE
}
|
|
923
|
+
|
|
424
924
|
static __device__ __forceinline__ void mma(
|
|
425
925
|
tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
|
|
426
926
|
#ifdef TURING_MMA_AVAILABLE
|
|
@@ -461,8 +961,9 @@ namespace ggml_cuda_mma {
|
|
|
461
961
|
#endif // AMPERE_MMA_AVAILABLE
|
|
462
962
|
}
|
|
463
963
|
|
|
964
|
+
template <data_layout dl_ab, data_layout dl_d>
|
|
464
965
|
static __device__ __forceinline__ void mma(
|
|
465
|
-
tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
|
|
966
|
+
tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) {
|
|
466
967
|
#ifdef TURING_MMA_AVAILABLE
|
|
467
968
|
const int * Axi = (const int *) A.x;
|
|
468
969
|
const int * Bxi = (const int *) B.x;
|
|
@@ -489,14 +990,62 @@ namespace ggml_cuda_mma {
|
|
|
489
990
|
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
|
490
991
|
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
|
|
491
992
|
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
|
993
|
+
#elif defined(AMD_WMMA_AVAILABLE)
|
|
994
|
+
#if defined(RDNA4)
|
|
995
|
+
using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
|
|
996
|
+
using floatx8_t = __attribute__((ext_vector_type(8))) float;
|
|
997
|
+
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
|
|
998
|
+
const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
|
|
999
|
+
const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
|
|
1000
|
+
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
|
|
1001
|
+
#elif defined(RDNA3)
|
|
1002
|
+
using halfx16_t = __attribute__((ext_vector_type(16))) _Float16;
|
|
1003
|
+
using floatx8_t = __attribute__((ext_vector_type(8))) float;
|
|
1004
|
+
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
|
|
1005
|
+
const halfx16_t& a_frag = reinterpret_cast<const halfx16_t&>(A.x[0]);
|
|
1006
|
+
const halfx16_t& b_frag = reinterpret_cast<const halfx16_t&>(B.x[0]);
|
|
1007
|
+
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, acc_frag);
|
|
1008
|
+
#else
|
|
1009
|
+
GGML_UNUSED_VARS(D, A, B);
|
|
1010
|
+
NO_DEVICE_CODE;
|
|
1011
|
+
#endif // RDNA4
|
|
492
1012
|
#else
|
|
493
1013
|
GGML_UNUSED_VARS(D, A, B);
|
|
494
1014
|
NO_DEVICE_CODE;
|
|
495
1015
|
#endif // TURING_MMA_AVAILABLE
|
|
496
1016
|
}
|
|
497
1017
|
|
|
1018
|
+
template <data_layout dl_ab, data_layout dl_d>
|
|
498
1019
|
static __device__ __forceinline__ void mma(
|
|
499
|
-
tile<16, 16,
|
|
1020
|
+
tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {
|
|
1021
|
+
#if defined(AMD_WMMA_AVAILABLE)
|
|
1022
|
+
#if defined(RDNA4)
|
|
1023
|
+
using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
|
|
1024
|
+
using floatx8_t = __attribute__((ext_vector_type(8))) float;
|
|
1025
|
+
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
|
|
1026
|
+
const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
|
|
1027
|
+
const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
|
|
1028
|
+
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
|
|
1029
|
+
#elif defined(RDNA3)
|
|
1030
|
+
using bf16x16_t = __attribute__((ext_vector_type(16))) __bf16;
|
|
1031
|
+
using floatx8_t = __attribute__((ext_vector_type(8))) float;
|
|
1032
|
+
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
|
|
1033
|
+
const bf16x16_t& a_frag = reinterpret_cast<const bf16x16_t&>(A.x[0]);
|
|
1034
|
+
const bf16x16_t& b_frag = reinterpret_cast<const bf16x16_t&>(B.x[0]);
|
|
1035
|
+
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_frag, b_frag, acc_frag);
|
|
1036
|
+
#else
|
|
1037
|
+
GGML_UNUSED_VARS(D, A, B);
|
|
1038
|
+
NO_DEVICE_CODE;
|
|
1039
|
+
#endif // RDNA4
|
|
1040
|
+
#else
|
|
1041
|
+
GGML_UNUSED_VARS(D, A, B);
|
|
1042
|
+
NO_DEVICE_CODE;
|
|
1043
|
+
#endif // AMPERE_MMA_AVAILABLE
|
|
1044
|
+
}
|
|
1045
|
+
|
|
1046
|
+
template <data_layout dl_d, data_layout dl_ab>
|
|
1047
|
+
static __device__ __forceinline__ void mma(
|
|
1048
|
+
tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {
|
|
500
1049
|
#if defined(AMD_MFMA_AVAILABLE)
|
|
501
1050
|
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
|
|
502
1051
|
int32x4_t * acc = (int32x4_t *) D.x;
|
|
@@ -515,6 +1064,59 @@ namespace ggml_cuda_mma {
|
|
|
515
1064
|
acc[0],
|
|
516
1065
|
0, 0, 0);
|
|
517
1066
|
#endif // defined(CDNA3)
|
|
1067
|
+
|
|
1068
|
+
#elif defined(AMD_WMMA_AVAILABLE)
|
|
1069
|
+
|
|
1070
|
+
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
|
|
1071
|
+
int32x8_t * acc = (int32x8_t *) D.x;
|
|
1072
|
+
|
|
1073
|
+
#if defined(RDNA4)
|
|
1074
|
+
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
|
|
1075
|
+
int32x2_t * a_vec = (int32x2_t *) A.x;
|
|
1076
|
+
int32x2_t * b_vec = (int32x2_t *) B.x;
|
|
1077
|
+
|
|
1078
|
+
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
|
|
1079
|
+
true,
|
|
1080
|
+
a_vec[0],
|
|
1081
|
+
true,
|
|
1082
|
+
b_vec[0],
|
|
1083
|
+
acc[0],
|
|
1084
|
+
true
|
|
1085
|
+
);
|
|
1086
|
+
|
|
1087
|
+
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
|
|
1088
|
+
true,
|
|
1089
|
+
a_vec[1],
|
|
1090
|
+
true,
|
|
1091
|
+
b_vec[1],
|
|
1092
|
+
acc[0],
|
|
1093
|
+
true
|
|
1094
|
+
);
|
|
1095
|
+
|
|
1096
|
+
#elif defined(RDNA3)
|
|
1097
|
+
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
|
|
1098
|
+
int32x4_t * a_vec = (int32x4_t *) A.x;
|
|
1099
|
+
int32x4_t * b_vec = (int32x4_t *) B.x;
|
|
1100
|
+
|
|
1101
|
+
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
|
|
1102
|
+
true,
|
|
1103
|
+
a_vec[0],
|
|
1104
|
+
true,
|
|
1105
|
+
b_vec[0],
|
|
1106
|
+
acc[0],
|
|
1107
|
+
true
|
|
1108
|
+
);
|
|
1109
|
+
|
|
1110
|
+
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
|
|
1111
|
+
true,
|
|
1112
|
+
a_vec[1],
|
|
1113
|
+
true,
|
|
1114
|
+
b_vec[1],
|
|
1115
|
+
acc[0],
|
|
1116
|
+
true
|
|
1117
|
+
);
|
|
1118
|
+
#endif // RDNA4
|
|
1119
|
+
|
|
518
1120
|
#else
|
|
519
1121
|
GGML_UNUSED_VARS(D, A, B);
|
|
520
1122
|
NO_DEVICE_CODE;
|
|
@@ -541,9 +1143,100 @@ namespace ggml_cuda_mma {
|
|
|
541
1143
|
acc[0],
|
|
542
1144
|
0, 0, 0);
|
|
543
1145
|
#endif // defined(CDNA3)
|
|
1146
|
+
|
|
544
1147
|
#else
|
|
545
1148
|
GGML_UNUSED_VARS(D, A, B);
|
|
546
1149
|
NO_DEVICE_CODE;
|
|
547
1150
|
#endif // AMD_MFMA_AVAILABLE
|
|
548
1151
|
}
|
|
1152
|
+
|
|
1153
|
+
template <typename T1, typename T2, int J, int K>
|
|
1154
|
+
static __device__ __forceinline__ void mma(
|
|
1155
|
+
tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
|
|
1156
|
+
tile <16, J, T1> * D16 = reinterpret_cast< tile<16, J, T1> *>(&D);
|
|
1157
|
+
const tile<16, K, T2> * A16 = reinterpret_cast<const tile<16, K, T2> *>(&A);
|
|
1158
|
+
mma(D16[0], A16[0], B);
|
|
1159
|
+
mma(D16[1], A16[1], B);
|
|
1160
|
+
}
|
|
1161
|
+
|
|
1162
|
+
// Volta mma with float accumulation (in place on D).
// Two m8n8k4 mma.sync instructions are issued on the same 8 accumulator
// registers: the first consumes registers 0-1 of A and B, the second
// registers 2-3. Any target other than Volta hits NO_DEVICE_CODE.
static __device__ __forceinline__ void mma(
        tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    const int * a_regs = (const int *) A.x;
    const int * b_regs = (const int *) B.x;
    int *       d_regs = (int *) D.x;

    asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
        : "+r"(d_regs[0]), "+r"(d_regs[1]), "+r"(d_regs[2]), "+r"(d_regs[3]), "+r"(d_regs[4]), "+r"(d_regs[5]), "+r"(d_regs[6]), "+r"(d_regs[7])
        : "r"(a_regs[0]), "r"(a_regs[1]), "r"(b_regs[0]), "r"(b_regs[1]));
    asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
        : "+r"(d_regs[0]), "+r"(d_regs[1]), "+r"(d_regs[2]), "+r"(d_regs[3]), "+r"(d_regs[4]), "+r"(d_regs[5]), "+r"(d_regs[6]), "+r"(d_regs[7])
        : "r"(a_regs[2]), "r"(a_regs[3]), "r"(b_regs[2]), "r"(b_regs[3]));
#else
    GGML_UNUSED_VARS(D, A, B);
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
}
|
|
1181
|
+
|
|
1182
|
+
// Volta mma with half-precision accumulation (in place on D).
// Two m8n8k4 mma.sync instructions are issued on the same 4 accumulator
// registers: the first consumes registers 0-1 of A and B, the second
// registers 2-3. Any target other than Volta hits NO_DEVICE_CODE.
static __device__ __forceinline__ void mma(
        tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    const int * a_regs = (const int *) A.x;
    const int * b_regs = (const int *) B.x;
    int *       d_regs = (int *) D.x;

    asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
        "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
        : "+r"(d_regs[0]), "+r"(d_regs[1]), "+r"(d_regs[2]), "+r"(d_regs[3])
        : "r"(a_regs[0]), "r"(a_regs[1]), "r"(b_regs[0]), "r"(b_regs[1]));
    asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
        "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
        : "+r"(d_regs[0]), "+r"(d_regs[1]), "+r"(d_regs[2]), "+r"(d_regs[3])
        : "r"(a_regs[2]), "r"(a_regs[3]), "r"(b_regs[2]), "r"(b_regs[3]));
#else
    GGML_UNUSED_VARS(D, A, B);
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
}
|
|
1201
|
+
|
|
1202
|
+
template <data_layout dl_d, data_layout dl_ab>
|
|
1203
|
+
static __device__ __forceinline__ void mma(
|
|
1204
|
+
tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
|
|
1205
|
+
#if defined(AMD_WMMA_AVAILABLE)
|
|
1206
|
+
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
|
|
1207
|
+
int32x8_t * acc = (int32x8_t *) D.x;
|
|
1208
|
+
#if defined(RDNA4)
|
|
1209
|
+
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
|
|
1210
|
+
int32x2_t * a_vec = (int32x2_t *) A.x;
|
|
1211
|
+
int32x2_t * b_vec = (int32x2_t *) B.x;
|
|
1212
|
+
|
|
1213
|
+
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
|
|
1214
|
+
true,
|
|
1215
|
+
a_vec[0],
|
|
1216
|
+
true,
|
|
1217
|
+
b_vec[0],
|
|
1218
|
+
acc[0],
|
|
1219
|
+
false
|
|
1220
|
+
);
|
|
1221
|
+
#elif defined(RDNA3)
|
|
1222
|
+
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
|
|
1223
|
+
int32x4_t * a_vec = (int32x4_t *) A.x;
|
|
1224
|
+
int32x4_t * b_vec = (int32x4_t *) B.x;
|
|
1225
|
+
|
|
1226
|
+
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
|
|
1227
|
+
true,
|
|
1228
|
+
a_vec[0],
|
|
1229
|
+
true,
|
|
1230
|
+
b_vec[0],
|
|
1231
|
+
acc[0],
|
|
1232
|
+
false
|
|
1233
|
+
);
|
|
1234
|
+
#endif // RDNA4
|
|
1235
|
+
#else
|
|
1236
|
+
GGML_UNUSED(D);
|
|
1237
|
+
GGML_UNUSED(A);
|
|
1238
|
+
GGML_UNUSED(B);
|
|
1239
|
+
NO_DEVICE_CODE;
|
|
1240
|
+
#endif // AMD_WMMA_AVAILABLE
|
|
1241
|
+
}
|
|
549
1242
|
}
|