whispercpp 1.3.4 → 1.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +47 -23
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/cli/cli.cpp +121 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +10 -11
- data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
- data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
- data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
- data/ext/sources/examples/talk-llama/llama-context.h +57 -9
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
- data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
- data/ext/sources/examples/talk-llama/llama-model.h +44 -3
- data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
- data/ext/sources/examples/talk-llama/llama.cpp +729 -2
- data/ext/sources/examples/talk-llama/llama.h +152 -14
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +82 -54
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +4 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +190 -12
- data/ext/sources/ggml/src/CMakeLists.txt +82 -11
- data/ext/sources/ggml/src/ggml-alloc.c +124 -41
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
- data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
- data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
- data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
- data/ext/sources/ggml/src/ggml-impl.h +67 -6
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +425 -33
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +101 -35
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +119 -2
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +7 -0
- data/whispercpp.gemspec +1 -1
- metadata +287 -34
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -1,20 +1,21 @@
|
|
|
1
1
|
#include "ggml.h"
|
|
2
2
|
#include "common.cuh"
|
|
3
|
-
#include "
|
|
3
|
+
#include "unary.cuh"
|
|
4
4
|
#include "mmvf.cuh"
|
|
5
|
+
#include "convert.cuh"
|
|
5
6
|
|
|
6
|
-
template <typename T, typename type_acc, int ncols_dst, int block_size>
|
|
7
|
+
template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false>
|
|
7
8
|
static __global__ void mul_mat_vec_f(
|
|
8
|
-
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
|
9
|
+
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
|
|
9
10
|
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
|
|
10
|
-
const
|
|
11
|
-
const
|
|
11
|
+
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
|
12
|
+
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
|
12
13
|
const int row = blockIdx.x;
|
|
13
14
|
const int channel_dst = blockIdx.y;
|
|
14
|
-
const int channel_x = ids ? ids[channel_dst] : channel_dst
|
|
15
|
+
const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio);
|
|
15
16
|
const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
|
|
16
17
|
const int sample_dst = blockIdx.z;
|
|
17
|
-
const int sample_x = sample_dst
|
|
18
|
+
const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio);
|
|
18
19
|
const int sample_y = sample_dst;
|
|
19
20
|
const int tid = threadIdx.x;
|
|
20
21
|
|
|
@@ -24,58 +25,164 @@ static __global__ void mul_mat_vec_f(
|
|
|
24
25
|
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
|
|
25
26
|
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
|
|
26
27
|
|
|
28
|
+
bool use_gate = false;
|
|
29
|
+
bool use_bias = false;
|
|
30
|
+
bool use_gate_bias = false;
|
|
31
|
+
ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU;
|
|
32
|
+
const T * gate_x = nullptr;
|
|
33
|
+
const float * x_bias = nullptr;
|
|
34
|
+
const float * gate_bias = nullptr;
|
|
35
|
+
|
|
36
|
+
if constexpr (has_fusion) {
|
|
37
|
+
use_gate = fusion.gate != nullptr;
|
|
38
|
+
use_bias = fusion.x_bias != nullptr;
|
|
39
|
+
use_gate_bias = fusion.gate_bias != nullptr;
|
|
40
|
+
glu_op = fusion.glu_op;
|
|
41
|
+
|
|
42
|
+
if (use_gate) {
|
|
43
|
+
gate_x = static_cast<const T *>(fusion.gate);
|
|
44
|
+
}
|
|
45
|
+
if (use_bias) {
|
|
46
|
+
x_bias = static_cast<const float *>(fusion.x_bias);
|
|
47
|
+
}
|
|
48
|
+
if (use_gate_bias) {
|
|
49
|
+
gate_bias = static_cast<const float *>(fusion.gate_bias);
|
|
50
|
+
use_gate_bias = use_gate;
|
|
51
|
+
} else {
|
|
52
|
+
use_gate_bias = false;
|
|
53
|
+
}
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
if (use_gate) {
|
|
57
|
+
gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
|
|
58
|
+
}
|
|
59
|
+
if constexpr (has_fusion) {
|
|
60
|
+
const int channel_bias = ids ? channel_x : channel_dst;
|
|
61
|
+
if (use_bias) {
|
|
62
|
+
x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
|
|
63
|
+
}
|
|
64
|
+
if (use_gate_bias) {
|
|
65
|
+
gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
|
|
66
|
+
}
|
|
67
|
+
}
|
|
68
|
+
|
|
27
69
|
const float2 * y2 = (const float2 *) y;
|
|
28
70
|
|
|
29
71
|
extern __shared__ char data_mmv[];
|
|
30
72
|
float * buf_iw = (float *) data_mmv;
|
|
73
|
+
float * buf_iw_gate = nullptr;
|
|
74
|
+
if constexpr (has_fusion) {
|
|
75
|
+
buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float));
|
|
76
|
+
}
|
|
31
77
|
|
|
32
78
|
if (block_size > warp_size) {
|
|
33
79
|
if (tid < warp_size) {
|
|
34
80
|
buf_iw[tid] = 0.0f;
|
|
81
|
+
if constexpr (has_fusion) {
|
|
82
|
+
if (use_gate) {
|
|
83
|
+
buf_iw_gate[tid] = 0.0f;
|
|
84
|
+
}
|
|
85
|
+
}
|
|
35
86
|
}
|
|
36
87
|
__syncthreads();
|
|
37
88
|
}
|
|
38
89
|
|
|
39
90
|
float sumf[ncols_dst] = {0.0f};
|
|
91
|
+
float sumf_gate[ncols_dst];
|
|
92
|
+
if constexpr (has_fusion) {
|
|
93
|
+
#pragma unroll
|
|
94
|
+
for (int j = 0; j < ncols_dst; ++j) {
|
|
95
|
+
sumf_gate[j] = 0.0f;
|
|
96
|
+
}
|
|
97
|
+
}
|
|
40
98
|
|
|
41
99
|
if constexpr (std::is_same_v<T, float>) {
|
|
42
100
|
const float2 * x2 = (const float2 *) x;
|
|
101
|
+
const float2 * gate_x2 = nullptr;
|
|
102
|
+
if constexpr (has_fusion) {
|
|
103
|
+
if (use_gate) {
|
|
104
|
+
gate_x2 = (const float2 *) gate_x;
|
|
105
|
+
}
|
|
106
|
+
}
|
|
43
107
|
|
|
44
108
|
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
|
45
109
|
const float2 tmpx = x2[col2];
|
|
110
|
+
float2 tmpx_gate = make_float2(0.0f, 0.0f);
|
|
111
|
+
if constexpr (has_fusion) {
|
|
112
|
+
if (use_gate) {
|
|
113
|
+
tmpx_gate = gate_x2[col2];
|
|
114
|
+
}
|
|
115
|
+
}
|
|
46
116
|
|
|
47
117
|
#pragma unroll
|
|
48
118
|
for (int j = 0; j < ncols_dst; ++j) {
|
|
49
119
|
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
|
50
|
-
sumf[j]
|
|
51
|
-
sumf[j]
|
|
120
|
+
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
|
121
|
+
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
|
122
|
+
|
|
123
|
+
if constexpr (has_fusion) {
|
|
124
|
+
if (use_gate) {
|
|
125
|
+
ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
|
|
126
|
+
ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
|
|
127
|
+
}
|
|
128
|
+
}
|
|
52
129
|
}
|
|
53
130
|
}
|
|
54
131
|
} else if constexpr (std::is_same_v<T, half>) {
|
|
55
132
|
const half2 * x2 = (const half2 *) x;
|
|
133
|
+
const half2 * gate_x2 = nullptr;
|
|
134
|
+
if constexpr (has_fusion) {
|
|
135
|
+
if (use_gate) {
|
|
136
|
+
gate_x2 = (const half2 *) gate_x;
|
|
137
|
+
}
|
|
138
|
+
}
|
|
56
139
|
|
|
57
140
|
if (std::is_same_v<type_acc, float>) {
|
|
58
141
|
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
|
59
142
|
const float2 tmpx = __half22float2(x2[col2]);
|
|
60
|
-
|
|
143
|
+
float2 tmpx_gate = make_float2(0.0f, 0.0f);
|
|
144
|
+
if constexpr (has_fusion) {
|
|
145
|
+
if (use_gate) {
|
|
146
|
+
tmpx_gate = __half22float2(gate_x2[col2]);
|
|
147
|
+
}
|
|
148
|
+
}
|
|
61
149
|
#pragma unroll
|
|
62
150
|
for (int j = 0; j < ncols_dst; ++j) {
|
|
63
151
|
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
|
64
|
-
sumf[j]
|
|
65
|
-
sumf[j]
|
|
152
|
+
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
|
153
|
+
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
|
154
|
+
|
|
155
|
+
if constexpr (has_fusion) {
|
|
156
|
+
if (use_gate) {
|
|
157
|
+
ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
|
|
158
|
+
ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
|
|
159
|
+
}
|
|
160
|
+
}
|
|
66
161
|
}
|
|
67
162
|
}
|
|
68
163
|
} else {
|
|
69
164
|
#ifdef FP16_AVAILABLE
|
|
70
165
|
half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
|
|
166
|
+
half2 sumh2_gate[ncols_dst] = {{0.0f, 0.0f}};
|
|
71
167
|
|
|
72
168
|
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
|
73
169
|
const half2 tmpx = x2[col2];
|
|
74
|
-
|
|
170
|
+
half2 tmpx_gate = make_half2(0.0f, 0.0f);
|
|
171
|
+
if constexpr (has_fusion) {
|
|
172
|
+
if (use_gate) {
|
|
173
|
+
tmpx_gate = gate_x2[col2];
|
|
174
|
+
}
|
|
175
|
+
}
|
|
75
176
|
#pragma unroll
|
|
76
177
|
for (int j = 0; j < ncols_dst; ++j) {
|
|
77
178
|
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
|
78
179
|
sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
|
|
180
|
+
|
|
181
|
+
if constexpr (has_fusion) {
|
|
182
|
+
if (use_gate) {
|
|
183
|
+
sumh2_gate[j] += tmpx_gate * make_half2(tmpy.x, tmpy.y);
|
|
184
|
+
}
|
|
185
|
+
}
|
|
79
186
|
}
|
|
80
187
|
}
|
|
81
188
|
|
|
@@ -83,21 +190,86 @@ static __global__ void mul_mat_vec_f(
|
|
|
83
190
|
for (int j = 0; j < ncols_dst; ++j) {
|
|
84
191
|
sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
|
|
85
192
|
}
|
|
193
|
+
|
|
194
|
+
if constexpr (has_fusion) {
|
|
195
|
+
if (use_gate) {
|
|
196
|
+
#pragma unroll
|
|
197
|
+
for (int j = 0; j < ncols_dst; ++j) {
|
|
198
|
+
sumf_gate[j] = __low2float(sumh2_gate[j]) + __high2float(sumh2_gate[j]);
|
|
199
|
+
}
|
|
200
|
+
}
|
|
201
|
+
}
|
|
86
202
|
#else
|
|
87
203
|
NO_DEVICE_CODE;
|
|
88
204
|
#endif // FP16_AVAILABLE
|
|
89
205
|
}
|
|
90
206
|
} else if constexpr (std::is_same_v<T, nv_bfloat16>) {
|
|
207
|
+
//TODO: add support for ggml_cuda_mad for hip_bfloat162
|
|
208
|
+
#if defined(GGML_USE_HIP)
|
|
91
209
|
const int * x2 = (const int *) x;
|
|
210
|
+
const int * gate_x2 = nullptr;
|
|
211
|
+
if constexpr (has_fusion) {
|
|
212
|
+
if (use_gate) {
|
|
213
|
+
gate_x2 = (const int *) gate_x;
|
|
214
|
+
}
|
|
215
|
+
}
|
|
92
216
|
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
|
93
217
|
const int tmpx = x2[col2];
|
|
218
|
+
int tmpx_gate = 0;
|
|
219
|
+
if constexpr (has_fusion) {
|
|
220
|
+
if (use_gate) {
|
|
221
|
+
tmpx_gate = gate_x2[col2];
|
|
222
|
+
}
|
|
223
|
+
}
|
|
224
|
+
#pragma unroll
|
|
225
|
+
for (int j = 0; j < ncols_dst; ++j) {
|
|
226
|
+
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
|
227
|
+
const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
|
|
228
|
+
const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
|
|
229
|
+
ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
|
|
230
|
+
ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
|
|
231
|
+
|
|
232
|
+
if constexpr (has_fusion) {
|
|
233
|
+
if (use_gate) {
|
|
234
|
+
const float tmpx0_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[0]);
|
|
235
|
+
const float tmpx1_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[1]);
|
|
236
|
+
ggml_cuda_mad(sumf_gate[j], tmpx0_gate, tmpy.x);
|
|
237
|
+
ggml_cuda_mad(sumf_gate[j], tmpx1_gate, tmpy.y);
|
|
238
|
+
}
|
|
239
|
+
}
|
|
240
|
+
}
|
|
241
|
+
}
|
|
242
|
+
#else
|
|
243
|
+
const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
|
|
244
|
+
const nv_bfloat162 * gate_x2 = nullptr;
|
|
245
|
+
if constexpr (has_fusion) {
|
|
246
|
+
if (use_gate) {
|
|
247
|
+
gate_x2 = (const nv_bfloat162 *) gate_x;
|
|
248
|
+
}
|
|
249
|
+
}
|
|
250
|
+
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
|
251
|
+
const nv_bfloat162 tmpx = x2[col2];
|
|
252
|
+
nv_bfloat162 tmpx_gate;
|
|
253
|
+
if constexpr (has_fusion) {
|
|
254
|
+
if (use_gate) {
|
|
255
|
+
tmpx_gate = gate_x2[col2];
|
|
256
|
+
}
|
|
257
|
+
}
|
|
94
258
|
#pragma unroll
|
|
95
259
|
for (int j = 0; j < ncols_dst; ++j) {
|
|
96
260
|
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
|
97
|
-
sumf[j]
|
|
98
|
-
sumf[j]
|
|
261
|
+
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
|
262
|
+
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
|
263
|
+
|
|
264
|
+
if constexpr (has_fusion) {
|
|
265
|
+
if (use_gate) {
|
|
266
|
+
ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
|
|
267
|
+
ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
|
|
268
|
+
}
|
|
269
|
+
}
|
|
99
270
|
}
|
|
100
271
|
}
|
|
272
|
+
#endif
|
|
101
273
|
} else {
|
|
102
274
|
static_assert(std::is_same_v<T, void>, "unsupported type");
|
|
103
275
|
}
|
|
@@ -106,13 +278,31 @@ static __global__ void mul_mat_vec_f(
|
|
|
106
278
|
for (int j = 0; j < ncols_dst; ++j) {
|
|
107
279
|
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
|
|
108
280
|
|
|
281
|
+
if constexpr (has_fusion) {
|
|
282
|
+
if (use_gate) {
|
|
283
|
+
sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
|
|
284
|
+
}
|
|
285
|
+
}
|
|
286
|
+
|
|
109
287
|
if (block_size > warp_size) {
|
|
110
288
|
buf_iw[tid/warp_size] = sumf[j];
|
|
289
|
+
if constexpr (has_fusion) {
|
|
290
|
+
if (use_gate) {
|
|
291
|
+
buf_iw_gate[tid/warp_size] = sumf_gate[j];
|
|
292
|
+
}
|
|
293
|
+
}
|
|
111
294
|
__syncthreads();
|
|
112
295
|
if (tid < warp_size) {
|
|
113
296
|
sumf[j] = buf_iw[tid];
|
|
114
297
|
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
|
|
298
|
+
if constexpr (has_fusion) {
|
|
299
|
+
if (use_gate) {
|
|
300
|
+
sumf_gate[j] = buf_iw_gate[tid];
|
|
301
|
+
sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
|
|
302
|
+
}
|
|
303
|
+
}
|
|
115
304
|
}
|
|
305
|
+
|
|
116
306
|
if (j < ncols_dst) {
|
|
117
307
|
__syncthreads();
|
|
118
308
|
}
|
|
@@ -123,12 +313,74 @@ static __global__ void mul_mat_vec_f(
|
|
|
123
313
|
return;
|
|
124
314
|
}
|
|
125
315
|
|
|
126
|
-
|
|
316
|
+
float value = sumf[tid];
|
|
317
|
+
|
|
318
|
+
if constexpr (has_fusion) {
|
|
319
|
+
if (use_bias) {
|
|
320
|
+
value += x_bias[tid*stride_col_dst + row];
|
|
321
|
+
}
|
|
322
|
+
|
|
323
|
+
if (use_gate) {
|
|
324
|
+
float gate_value = sumf_gate[tid];
|
|
325
|
+
if (use_gate_bias) {
|
|
326
|
+
gate_value += gate_bias[tid*stride_col_dst + row];
|
|
327
|
+
}
|
|
328
|
+
switch (glu_op) {
|
|
329
|
+
case GGML_GLU_OP_SWIGLU:
|
|
330
|
+
value *= ggml_cuda_op_silu_single(gate_value);
|
|
331
|
+
break;
|
|
332
|
+
case GGML_GLU_OP_GEGLU:
|
|
333
|
+
value *= ggml_cuda_op_gelu_single(gate_value);
|
|
334
|
+
break;
|
|
335
|
+
case GGML_GLU_OP_SWIGLU_OAI: {
|
|
336
|
+
value = ggml_cuda_op_swiglu_oai_single(gate_value, value);
|
|
337
|
+
break;
|
|
338
|
+
}
|
|
339
|
+
default:
|
|
340
|
+
break;
|
|
341
|
+
}
|
|
342
|
+
}
|
|
343
|
+
}
|
|
344
|
+
|
|
345
|
+
dst[tid*stride_col_dst + row] = value;
|
|
346
|
+
|
|
347
|
+
if constexpr (!has_fusion) {
|
|
348
|
+
GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, glu_op, gate_x, x_bias, gate_bias, sumf_gate);
|
|
349
|
+
}
|
|
350
|
+
}
|
|
351
|
+
|
|
352
|
+
template<typename T, typename type_acc, int ncols_dst, int block_size>
|
|
353
|
+
static void mul_mat_vec_f_switch_fusion(
|
|
354
|
+
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
|
355
|
+
const int64_t ncols, const int64_t nrows,
|
|
356
|
+
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
|
357
|
+
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
|
358
|
+
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
|
359
|
+
const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) {
|
|
360
|
+
|
|
361
|
+
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
|
|
362
|
+
if constexpr (ncols_dst == 1) {
|
|
363
|
+
if (has_fusion) {
|
|
364
|
+
mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>>
|
|
365
|
+
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
366
|
+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
367
|
+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
368
|
+
return;
|
|
369
|
+
}
|
|
370
|
+
}
|
|
371
|
+
|
|
372
|
+
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
|
|
373
|
+
|
|
374
|
+
mul_mat_vec_f<T, type_acc, ncols_dst, block_size><<<block_nums, block_dims, nbytes_shared, stream>>>
|
|
375
|
+
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
376
|
+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
377
|
+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
378
|
+
|
|
127
379
|
}
|
|
128
380
|
|
|
129
381
|
template <typename T, typename type_acc, int ncols_dst>
|
|
130
|
-
|
|
131
|
-
const T * x, const float * y, const int32_t * ids, float * dst,
|
|
382
|
+
void launch_mul_mat_vec_f_cuda(
|
|
383
|
+
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
|
132
384
|
const int64_t ncols, const int64_t nrows,
|
|
133
385
|
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
|
134
386
|
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
|
@@ -140,8 +392,8 @@ static void launch_mul_mat_vec_f_cuda(
|
|
|
140
392
|
GGML_ASSERT(stride_col_y % 2 == 0);
|
|
141
393
|
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
|
|
142
394
|
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
|
|
143
|
-
const
|
|
144
|
-
const
|
|
395
|
+
const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
|
|
396
|
+
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
|
|
145
397
|
|
|
146
398
|
const int device = ggml_cuda_get_device();
|
|
147
399
|
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
|
@@ -160,57 +412,59 @@ static void launch_mul_mat_vec_f_cuda(
|
|
|
160
412
|
}
|
|
161
413
|
}
|
|
162
414
|
|
|
163
|
-
const
|
|
415
|
+
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
|
|
416
|
+
|
|
417
|
+
const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
|
|
164
418
|
const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
|
|
165
419
|
const dim3 block_dims(block_size_best, 1, 1);
|
|
166
420
|
switch (block_size_best) {
|
|
167
421
|
case 32: {
|
|
168
|
-
|
|
169
|
-
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
170
|
-
|
|
171
|
-
|
|
422
|
+
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32>
|
|
423
|
+
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
424
|
+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
425
|
+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
|
172
426
|
} break;
|
|
173
427
|
case 64: {
|
|
174
|
-
|
|
175
|
-
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
176
|
-
|
|
177
|
-
|
|
428
|
+
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64>
|
|
429
|
+
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
430
|
+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
431
|
+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
|
178
432
|
} break;
|
|
179
433
|
case 96: {
|
|
180
|
-
|
|
181
|
-
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
182
|
-
|
|
183
|
-
|
|
434
|
+
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96>
|
|
435
|
+
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
436
|
+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
437
|
+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
|
184
438
|
} break;
|
|
185
439
|
case 128: {
|
|
186
|
-
|
|
187
|
-
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
188
|
-
|
|
189
|
-
|
|
440
|
+
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128>
|
|
441
|
+
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
442
|
+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
443
|
+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
|
190
444
|
} break;
|
|
191
445
|
case 160: {
|
|
192
|
-
|
|
193
|
-
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
194
|
-
|
|
195
|
-
|
|
446
|
+
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160>
|
|
447
|
+
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
448
|
+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
449
|
+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
|
196
450
|
} break;
|
|
197
451
|
case 192: {
|
|
198
|
-
|
|
199
|
-
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
200
|
-
|
|
201
|
-
|
|
452
|
+
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192>
|
|
453
|
+
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
454
|
+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
455
|
+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
|
202
456
|
} break;
|
|
203
457
|
case 224: {
|
|
204
|
-
|
|
205
|
-
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
206
|
-
|
|
207
|
-
|
|
458
|
+
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224>
|
|
459
|
+
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
460
|
+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
461
|
+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
|
208
462
|
} break;
|
|
209
463
|
case 256: {
|
|
210
|
-
|
|
211
|
-
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
212
|
-
|
|
213
|
-
|
|
464
|
+
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256>
|
|
465
|
+
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
466
|
+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
467
|
+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
|
214
468
|
} break;
|
|
215
469
|
default: {
|
|
216
470
|
GGML_ABORT("fatal error");
|
|
@@ -220,7 +474,7 @@ static void launch_mul_mat_vec_f_cuda(
|
|
|
220
474
|
|
|
221
475
|
template <typename T, typename type_acc>
|
|
222
476
|
static void mul_mat_vec_f_cuda_switch_ncols_dst(
|
|
223
|
-
const T * x, const float * y, const int32_t * ids, float * dst,
|
|
477
|
+
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
|
224
478
|
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
|
|
225
479
|
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
|
226
480
|
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
|
@@ -230,49 +484,49 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
|
|
|
230
484
|
switch (ncols_dst) {
|
|
231
485
|
case 1:
|
|
232
486
|
launch_mul_mat_vec_f_cuda<T, type_acc, 1>
|
|
233
|
-
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
487
|
+
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
234
488
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
|
235
489
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
236
490
|
break;
|
|
237
491
|
case 2:
|
|
238
492
|
launch_mul_mat_vec_f_cuda<T, type_acc, 2>
|
|
239
|
-
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
493
|
+
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
240
494
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
|
241
495
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
242
496
|
break;
|
|
243
497
|
case 3:
|
|
244
498
|
launch_mul_mat_vec_f_cuda<T, type_acc, 3>
|
|
245
|
-
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
499
|
+
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
246
500
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
|
247
501
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
248
502
|
break;
|
|
249
503
|
case 4:
|
|
250
504
|
launch_mul_mat_vec_f_cuda<T, type_acc, 4>
|
|
251
|
-
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
505
|
+
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
252
506
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
|
253
507
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
254
508
|
break;
|
|
255
509
|
case 5:
|
|
256
510
|
launch_mul_mat_vec_f_cuda<T, type_acc, 5>
|
|
257
|
-
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
511
|
+
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
258
512
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
|
259
513
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
260
514
|
break;
|
|
261
515
|
case 6:
|
|
262
516
|
launch_mul_mat_vec_f_cuda<T, type_acc, 6>
|
|
263
|
-
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
517
|
+
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
264
518
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
|
265
519
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
266
520
|
break;
|
|
267
521
|
case 7:
|
|
268
522
|
launch_mul_mat_vec_f_cuda<T, type_acc, 7>
|
|
269
|
-
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
523
|
+
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
270
524
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
|
271
525
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
272
526
|
break;
|
|
273
527
|
case 8:
|
|
274
528
|
launch_mul_mat_vec_f_cuda<T, type_acc, 8>
|
|
275
|
-
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
529
|
+
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
276
530
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
|
277
531
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
278
532
|
break;
|
|
@@ -284,29 +538,31 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
|
|
|
284
538
|
|
|
285
539
|
template<typename T>
|
|
286
540
|
static void mul_mat_vec_f_cuda(
|
|
287
|
-
const T * x, const float * y, const int32_t * ids, float * dst,
|
|
541
|
+
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
|
288
542
|
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
|
|
289
543
|
const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
|
|
290
544
|
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
|
291
545
|
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
|
292
546
|
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
|
293
547
|
enum ggml_prec prec, cudaStream_t stream) {
|
|
548
|
+
|
|
294
549
|
if constexpr(std::is_same_v<T, half>) {
|
|
295
550
|
if (prec == GGML_PREC_DEFAULT) {
|
|
296
551
|
mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
|
|
297
|
-
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
298
|
-
|
|
299
|
-
|
|
552
|
+
(x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
553
|
+
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
|
554
|
+
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
300
555
|
return;
|
|
301
556
|
}
|
|
302
557
|
}
|
|
303
558
|
mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
|
|
304
|
-
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
305
|
-
|
|
306
|
-
|
|
559
|
+
(x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
560
|
+
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
|
561
|
+
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
307
562
|
}
|
|
308
563
|
|
|
309
|
-
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst
|
|
564
|
+
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
|
|
565
|
+
const ggml_cuda_mm_fusion_args_host * fusion) {
|
|
310
566
|
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
|
311
567
|
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
|
|
312
568
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
@@ -332,6 +588,30 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
|
|
|
332
588
|
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
|
|
333
589
|
float * dst_d = (float *) dst->data;
|
|
334
590
|
|
|
591
|
+
ggml_cuda_mm_fusion_args_device fusion_local{};
|
|
592
|
+
|
|
593
|
+
if (fusion) {
|
|
594
|
+
GGML_ASSERT( !ids || dst->ne[2] == 1);
|
|
595
|
+
GGML_ASSERT( ids || dst->ne[1] == 1);
|
|
596
|
+
if (fusion->x_bias) {
|
|
597
|
+
GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
|
|
598
|
+
GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
|
|
599
|
+
GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
|
|
600
|
+
fusion_local.x_bias = fusion->x_bias->data;
|
|
601
|
+
}
|
|
602
|
+
if (fusion->gate) {
|
|
603
|
+
GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
|
|
604
|
+
fusion_local.gate = fusion->gate->data;
|
|
605
|
+
}
|
|
606
|
+
if (fusion->gate_bias) {
|
|
607
|
+
GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
|
|
608
|
+
GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
|
|
609
|
+
GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
|
|
610
|
+
fusion_local.gate_bias = fusion->gate_bias->data;
|
|
611
|
+
}
|
|
612
|
+
fusion_local.glu_op = fusion->glu_op;
|
|
613
|
+
}
|
|
614
|
+
|
|
335
615
|
const int64_t s01 = src0->nb[1] / ts_src0;
|
|
336
616
|
const int64_t s11 = src1->nb[1] / ts_src1;
|
|
337
617
|
const int64_t s1 = dst->nb[1] / ts_dst;
|
|
@@ -354,19 +634,19 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
|
|
|
354
634
|
switch (src0->type) {
|
|
355
635
|
case GGML_TYPE_F32: {
|
|
356
636
|
const float * src0_d = (const float *) src0->data;
|
|
357
|
-
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
|
637
|
+
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
|
358
638
|
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
|
359
639
|
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
|
360
640
|
} break;
|
|
361
641
|
case GGML_TYPE_F16: {
|
|
362
642
|
const half * src0_d = (const half *) src0->data;
|
|
363
|
-
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
|
643
|
+
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
|
364
644
|
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
|
365
645
|
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
|
366
646
|
} break;
|
|
367
647
|
case GGML_TYPE_BF16: {
|
|
368
648
|
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
|
|
369
|
-
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
|
649
|
+
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
|
370
650
|
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
|
371
651
|
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
|
372
652
|
} break;
|
|
@@ -393,7 +673,6 @@ void ggml_cuda_op_mul_mat_vec_f(
|
|
|
393
673
|
const int cc = ggml_cuda_info().devices[id].cc;
|
|
394
674
|
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
|
|
395
675
|
|
|
396
|
-
|
|
397
676
|
// ggml_cuda_op provides single, contiguous matrices
|
|
398
677
|
const int64_t stride_row = ne00;
|
|
399
678
|
const int64_t stride_col_y = ne10;
|
|
@@ -410,22 +689,23 @@ void ggml_cuda_op_mul_mat_vec_f(
|
|
|
410
689
|
const int64_t stride_sample_y = 0;
|
|
411
690
|
const int64_t stride_sample_dst = 0;
|
|
412
691
|
|
|
692
|
+
ggml_cuda_mm_fusion_args_device empty{};
|
|
413
693
|
switch (src0->type) {
|
|
414
694
|
case GGML_TYPE_F32: {
|
|
415
695
|
const float * src0_d = (const float *) src0_dd_i;
|
|
416
|
-
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
|
696
|
+
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
|
417
697
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
418
698
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
|
419
699
|
} break;
|
|
420
700
|
case GGML_TYPE_F16: {
|
|
421
701
|
const half * src0_d = (const half *) src0_dd_i;
|
|
422
|
-
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
|
702
|
+
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
|
423
703
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
424
704
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
|
425
705
|
} break;
|
|
426
706
|
case GGML_TYPE_BF16: {
|
|
427
707
|
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
|
|
428
|
-
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
|
708
|
+
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
|
429
709
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
430
710
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
|
431
711
|
} break;
|
|
@@ -436,10 +716,23 @@ void ggml_cuda_op_mul_mat_vec_f(
|
|
|
436
716
|
GGML_UNUSED_VARS(ctx, src1, dst, src1_ddq_i, src1_ncols, src1_padded_row_size);
|
|
437
717
|
}
|
|
438
718
|
|
|
439
|
-
bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
|
|
719
|
+
bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11) {
|
|
440
720
|
if (src0_ne[0] % 2 != 0) {
|
|
441
721
|
return false;
|
|
442
722
|
}
|
|
723
|
+
|
|
724
|
+
const size_t ts = ggml_type_size(type);
|
|
725
|
+
if (src0_nb[0] != ts) {
|
|
726
|
+
return false;
|
|
727
|
+
}
|
|
728
|
+
|
|
729
|
+
// Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash:
|
|
730
|
+
for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
|
|
731
|
+
if (src0_nb[i] % (2*ts) != 0) {
|
|
732
|
+
return false;
|
|
733
|
+
}
|
|
734
|
+
}
|
|
735
|
+
|
|
443
736
|
switch (type) {
|
|
444
737
|
case GGML_TYPE_F32:
|
|
445
738
|
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
|
|
@@ -472,7 +765,10 @@ bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0
|
|
|
472
765
|
return ne11 <= 8;
|
|
473
766
|
} else if (GGML_CUDA_CC_IS_AMD(cc)) {
|
|
474
767
|
if (fp16_mma_hardware_available(cc)) {
|
|
475
|
-
if (GGML_CUDA_CC_IS_RDNA3(cc)
|
|
768
|
+
if (GGML_CUDA_CC_IS_RDNA3(cc)) {
|
|
769
|
+
return ne11 <= 3;
|
|
770
|
+
}
|
|
771
|
+
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
|
|
476
772
|
return ne11 <= 5;
|
|
477
773
|
}
|
|
478
774
|
return ne11 <= 2;
|