whispercpp 1.3.4 → 1.3.5
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +47 -23
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/cli/cli.cpp +121 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +10 -11
- data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
- data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
- data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
- data/ext/sources/examples/talk-llama/llama-context.h +57 -9
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
- data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
- data/ext/sources/examples/talk-llama/llama-model.h +44 -3
- data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
- data/ext/sources/examples/talk-llama/llama.cpp +729 -2
- data/ext/sources/examples/talk-llama/llama.h +152 -14
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +82 -54
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +4 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +190 -12
- data/ext/sources/ggml/src/CMakeLists.txt +82 -11
- data/ext/sources/ggml/src/ggml-alloc.c +124 -41
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
- data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
- data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
- data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
- data/ext/sources/ggml/src/ggml-impl.h +67 -6
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +425 -33
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +101 -35
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +119 -2
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +7 -0
- data/whispercpp.gemspec +1 -1
- metadata +287 -34
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -21,10 +21,12 @@
|
|
|
21
21
|
#include "ggml-common.h"
|
|
22
22
|
|
|
23
23
|
#include <array>
|
|
24
|
+
#include <algorithm>
|
|
24
25
|
#include <cassert>
|
|
25
26
|
#include <cfloat>
|
|
26
27
|
#include <cstdio>
|
|
27
28
|
#include <string>
|
|
29
|
+
#include <unordered_map>
|
|
28
30
|
#include <vector>
|
|
29
31
|
|
|
30
32
|
#if defined(GGML_USE_HIP)
|
|
@@ -48,6 +50,10 @@
|
|
|
48
50
|
#define GGML_CUDA_CC_TURING 750
|
|
49
51
|
#define GGML_CUDA_CC_AMPERE 800
|
|
50
52
|
#define GGML_CUDA_CC_ADA_LOVELACE 890
|
|
53
|
+
// While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see
|
|
54
|
+
// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms
|
|
55
|
+
#define GGML_CUDA_CC_BLACKWELL 1200
|
|
56
|
+
#define GGML_CUDA_CC_RUBIN 1300
|
|
51
57
|
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
|
|
52
58
|
#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
|
|
53
59
|
#define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
|
|
@@ -65,31 +71,34 @@
|
|
|
65
71
|
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
|
|
66
72
|
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
|
|
67
73
|
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
|
|
74
|
+
#define GGML_CUDA_CC_RDNA3_5 (GGML_CUDA_CC_OFFSET_AMD + 0x1150) // AI 370, AI Max 395 laptops.
|
|
68
75
|
#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
|
|
69
76
|
|
|
70
|
-
#define GGML_CUDA_CC_IS_AMD(cc)
|
|
71
|
-
#define GGML_CUDA_CC_IS_RDNA(cc)
|
|
72
|
-
#define GGML_CUDA_CC_IS_RDNA1(cc)
|
|
73
|
-
#define GGML_CUDA_CC_IS_RDNA2(cc)
|
|
74
|
-
#define
|
|
75
|
-
#define
|
|
76
|
-
#define
|
|
77
|
-
#define
|
|
78
|
-
#define
|
|
79
|
-
#define
|
|
80
|
-
#define
|
|
77
|
+
#define GGML_CUDA_CC_IS_AMD(cc) (cc >= GGML_CUDA_CC_OFFSET_AMD)
|
|
78
|
+
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
|
|
79
|
+
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
|
|
80
|
+
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
|
|
81
|
+
#define GGML_CUDA_CC_IS_RDNA3_0(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA3_5)
|
|
82
|
+
#define GGML_CUDA_CC_IS_RDNA3_5(cc) (cc >= GGML_CUDA_CC_RDNA3_5 && cc < GGML_CUDA_CC_RDNA4)
|
|
83
|
+
#define GGML_CUDA_CC_IS_RDNA3(cc) (GGML_CUDA_CC_IS_RDNA3_0(cc) || GGML_CUDA_CC_IS_RDNA3_5(cc))
|
|
84
|
+
#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
|
|
85
|
+
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)
|
|
86
|
+
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
|
|
87
|
+
#define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2)
|
|
88
|
+
#define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3)
|
|
89
|
+
#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)
|
|
81
90
|
|
|
82
91
|
// Moore Threads
|
|
83
92
|
#define MUSART_HMASK 40300 // MUSA rc4.3, min. ver. for half2 -> uint mask comparisons
|
|
84
93
|
|
|
85
94
|
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
|
|
86
95
|
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
|
|
87
|
-
#define
|
|
96
|
+
#define GGML_CUDA_CC_PH1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // MTT S5000
|
|
88
97
|
|
|
89
98
|
#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
|
|
90
99
|
#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
|
|
91
|
-
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc <
|
|
92
|
-
#define
|
|
100
|
+
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_PH1)
|
|
101
|
+
#define GGML_CUDA_CC_IS_PH1(cc) (cc >= GGML_CUDA_CC_PH1)
|
|
93
102
|
|
|
94
103
|
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
|
|
95
104
|
# define GGML_CUDA_USE_CUB
|
|
@@ -212,26 +221,27 @@ static const char * cu_get_error_str(CUresult err) {
|
|
|
212
221
|
#define GGML_USE_VMM
|
|
213
222
|
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
|
|
214
223
|
|
|
215
|
-
#if defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
|
|
224
|
+
#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
|
|
216
225
|
#define FP16_AVAILABLE
|
|
217
|
-
#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
|
|
226
|
+
#endif // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
|
|
218
227
|
|
|
219
228
|
#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
|
|
220
229
|
#define FAST_FP16_AVAILABLE
|
|
221
230
|
#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
|
|
222
231
|
|
|
223
|
-
#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
|
|
224
|
-
#define FP16_MMA_AVAILABLE
|
|
225
|
-
#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
|
|
226
|
-
|
|
227
|
-
#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
|
|
228
|
-
#define FP16_MMA_AVAILABLE
|
|
229
|
-
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
|
|
230
|
-
|
|
231
232
|
#if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
|
|
232
233
|
#define AMD_MFMA_AVAILABLE
|
|
233
234
|
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
|
|
234
235
|
|
|
236
|
+
#if defined(GGML_USE_HIP) && (defined(RDNA4) || defined(RDNA3))
|
|
237
|
+
#define AMD_WMMA_AVAILABLE
|
|
238
|
+
#endif // defined(GGML_USE_HIP) && defined(RDNA4)
|
|
239
|
+
|
|
240
|
+
// The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
|
|
241
|
+
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
242
|
+
#define VOLTA_MMA_AVAILABLE
|
|
243
|
+
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
|
244
|
+
|
|
235
245
|
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
|
236
246
|
#define TURING_MMA_AVAILABLE
|
|
237
247
|
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
|
@@ -240,6 +250,10 @@ static const char * cu_get_error_str(CUresult err) {
|
|
|
240
250
|
#define AMPERE_MMA_AVAILABLE
|
|
241
251
|
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
|
242
252
|
|
|
253
|
+
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL && __CUDA_ARCH__ < GGML_CUDA_CC_RUBIN
|
|
254
|
+
# define BLACKWELL_MMA_AVAILABLE
|
|
255
|
+
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL
|
|
256
|
+
|
|
243
257
|
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
|
244
258
|
#define CP_ASYNC_AVAILABLE
|
|
245
259
|
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
|
@@ -249,11 +263,14 @@ static const char * cu_get_error_str(CUresult err) {
|
|
|
249
263
|
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
|
|
250
264
|
|
|
251
265
|
static bool fp16_available(const int cc) {
|
|
252
|
-
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL
|
|
266
|
+
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
|
|
267
|
+
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
|
|
253
268
|
}
|
|
254
269
|
|
|
255
270
|
static bool fast_fp16_available(const int cc) {
|
|
256
|
-
return (
|
|
271
|
+
return GGML_CUDA_CC_IS_AMD(cc) ||
|
|
272
|
+
(GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610) ||
|
|
273
|
+
(GGML_CUDA_CC_IS_MTHREADS(cc) && fp16_available(cc));
|
|
257
274
|
}
|
|
258
275
|
|
|
259
276
|
// To be used for feature selection of external libraries, e.g. cuBLAS.
|
|
@@ -262,27 +279,6 @@ static bool fast_fp16_hardware_available(const int cc) {
|
|
|
262
279
|
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
|
|
263
280
|
}
|
|
264
281
|
|
|
265
|
-
// Any FP16 tensor core instructions are available for ggml code.
|
|
266
|
-
static bool fp16_mma_available(const int cc) {
|
|
267
|
-
#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
|
|
268
|
-
return false;
|
|
269
|
-
#else
|
|
270
|
-
if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
|
|
271
|
-
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) ||
|
|
272
|
-
GGML_CUDA_CC_IS_MTHREADS(cc)) {
|
|
273
|
-
return true;
|
|
274
|
-
} else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
|
|
275
|
-
#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
|
|
276
|
-
return true;
|
|
277
|
-
#else
|
|
278
|
-
return false;
|
|
279
|
-
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
|
|
280
|
-
} else {
|
|
281
|
-
return false;
|
|
282
|
-
}
|
|
283
|
-
#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
|
|
284
|
-
}
|
|
285
|
-
|
|
286
282
|
// To be used for feature selection of external libraries, e.g. cuBLAS.
|
|
287
283
|
static bool fp16_mma_hardware_available(const int cc) {
|
|
288
284
|
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
|
|
@@ -291,7 +287,9 @@ static bool fp16_mma_hardware_available(const int cc) {
|
|
|
291
287
|
}
|
|
292
288
|
|
|
293
289
|
static bool bf16_mma_hardware_available(const int cc) {
|
|
294
|
-
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) ||
|
|
290
|
+
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) ||
|
|
291
|
+
GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3 ||
|
|
292
|
+
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
|
|
295
293
|
}
|
|
296
294
|
|
|
297
295
|
static bool fp32_mma_hardware_available(const int cc) {
|
|
@@ -306,7 +304,14 @@ static bool amd_mfma_available(const int cc) {
|
|
|
306
304
|
#endif //!defined(GGML_HIP_NO_MMQ_MFMA)
|
|
307
305
|
}
|
|
308
306
|
|
|
309
|
-
|
|
307
|
+
static bool amd_wmma_available(const int cc) {
|
|
308
|
+
return (GGML_CUDA_CC_IS_RDNA4(cc) || GGML_CUDA_CC_IS_RDNA3(cc));
|
|
309
|
+
}
|
|
310
|
+
|
|
311
|
+
static bool volta_mma_available(const int cc) {
|
|
312
|
+
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
|
|
313
|
+
}
|
|
314
|
+
|
|
310
315
|
static bool turing_mma_available(const int cc) {
|
|
311
316
|
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
|
|
312
317
|
}
|
|
@@ -319,6 +324,11 @@ static bool cp_async_available(const int cc) {
|
|
|
319
324
|
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
|
|
320
325
|
}
|
|
321
326
|
|
|
327
|
+
static bool blackwell_mma_available(const int cc) {
|
|
328
|
+
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_BLACKWELL &&
|
|
329
|
+
ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_RUBIN;
|
|
330
|
+
}
|
|
331
|
+
|
|
322
332
|
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
|
|
323
333
|
#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
|
|
324
334
|
return 64;
|
|
@@ -469,6 +479,53 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
|
|
|
469
479
|
return x;
|
|
470
480
|
}
|
|
471
481
|
|
|
482
|
+
template<typename T, int width = WARP_SIZE>
|
|
483
|
+
static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) {
|
|
484
|
+
const int lane_id = threadIdx.x % width;
|
|
485
|
+
#pragma unroll
|
|
486
|
+
for (int offset = 1; offset < width; offset <<= 1) {
|
|
487
|
+
const T t = __shfl_up_sync(0xffffffff, x, offset, width);
|
|
488
|
+
if (lane_id >= offset) {
|
|
489
|
+
x += t;
|
|
490
|
+
}
|
|
491
|
+
}
|
|
492
|
+
return x;
|
|
493
|
+
}
|
|
494
|
+
|
|
495
|
+
template<int width = WARP_SIZE>
|
|
496
|
+
static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) {
|
|
497
|
+
const int lane_id = threadIdx.x % width;
|
|
498
|
+
#pragma unroll
|
|
499
|
+
for (int offset = 1; offset < width; offset <<= 1) {
|
|
500
|
+
const float t_x = __shfl_up_sync(0xffffffff, a.x, offset, width);
|
|
501
|
+
const float t_y = __shfl_up_sync(0xffffffff, a.y, offset, width);
|
|
502
|
+
if (lane_id >= offset) {
|
|
503
|
+
a.x += t_x;
|
|
504
|
+
a.y += t_y;
|
|
505
|
+
}
|
|
506
|
+
}
|
|
507
|
+
return a;
|
|
508
|
+
}
|
|
509
|
+
|
|
510
|
+
template<int width = WARP_SIZE>
|
|
511
|
+
static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
|
|
512
|
+
#ifdef FP16_AVAILABLE
|
|
513
|
+
const int lane_id = threadIdx.x % width;
|
|
514
|
+
#pragma unroll
|
|
515
|
+
for (int offset = 1; offset < width; offset <<= 1) {
|
|
516
|
+
const half2 t = __shfl_up_sync(0xffffffff, a, offset, width);
|
|
517
|
+
if (lane_id >= offset) {
|
|
518
|
+
a = __hadd2(a, t);
|
|
519
|
+
}
|
|
520
|
+
}
|
|
521
|
+
return a;
|
|
522
|
+
|
|
523
|
+
#else
|
|
524
|
+
NO_DEVICE_CODE;
|
|
525
|
+
return a;
|
|
526
|
+
#endif // FP16_AVAILABLE
|
|
527
|
+
}
|
|
528
|
+
|
|
472
529
|
static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
|
|
473
530
|
#ifdef FP16_AVAILABLE
|
|
474
531
|
|
|
@@ -570,8 +627,12 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v
|
|
|
570
627
|
acc += v.y*u.y;
|
|
571
628
|
}
|
|
572
629
|
|
|
573
|
-
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
|
|
574
630
|
#if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
|
|
631
|
+
#define V_DOT2_F32_F16_AVAILABLE
|
|
632
|
+
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
|
|
633
|
+
|
|
634
|
+
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
|
|
635
|
+
#ifdef V_DOT2_F32_F16_AVAILABLE
|
|
575
636
|
asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
|
|
576
637
|
#else
|
|
577
638
|
#ifdef FAST_FP16_AVAILABLE
|
|
@@ -583,7 +644,7 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v,
|
|
|
583
644
|
acc += tmpv.x * tmpu.x;
|
|
584
645
|
acc += tmpv.y * tmpu.y;
|
|
585
646
|
#endif // FAST_FP16_AVAILABLE
|
|
586
|
-
#endif //
|
|
647
|
+
#endif // V_DOT2_F32_F16_AVAILABLE
|
|
587
648
|
}
|
|
588
649
|
|
|
589
650
|
static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {
|
|
@@ -600,8 +661,18 @@ static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v,
|
|
|
600
661
|
}
|
|
601
662
|
|
|
602
663
|
// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
|
|
664
|
+
// Important: do not use this function if dst and src both point at registers.
|
|
665
|
+
// Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types.
|
|
666
|
+
// The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions.
|
|
667
|
+
// If dst and src point at different address spaces then they are guaranteed to not be aliased.
|
|
603
668
|
template <int nbytes, int alignment = 0>
|
|
604
669
|
static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
|
|
670
|
+
static_assert(
|
|
671
|
+
nbytes <= ggml_cuda_get_max_cpy_bytes() || alignment == 0,
|
|
672
|
+
"You are misusing the alignment parameter for ggml_cuda_memcpy_1. "
|
|
673
|
+
"The intent is for the parameter is only as a workaround if either one of the pointers is not properly aligned. "
|
|
674
|
+
"If you use it to do more bytes per copy than ggml_cuda_max_cpy_bytes() the reads and writes may not be coalesced. "
|
|
675
|
+
"Call ggml_cuda_memcpy_1 in a loop instead.");
|
|
605
676
|
if constexpr (alignment != 0) {
|
|
606
677
|
static_assert(nbytes % alignment == 0, "bad alignment");
|
|
607
678
|
}
|
|
@@ -643,14 +714,39 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
|
|
|
643
714
|
#endif // CUDART_VERSION >= 12050
|
|
644
715
|
}
|
|
645
716
|
|
|
717
|
+
__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
|
|
718
|
+
const uint8_t sign_bit = (x < 0.0f) << 3;
|
|
719
|
+
float ax = fabsf(x) * e;
|
|
720
|
+
|
|
721
|
+
// Positive LUT
|
|
722
|
+
static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
|
|
723
|
+
|
|
724
|
+
int best_i = 0;
|
|
725
|
+
float best_err = fabsf(ax - pos_lut[0]);
|
|
726
|
+
|
|
727
|
+
#pragma unroll
|
|
728
|
+
for (int i = 1; i < 8; ++i) {
|
|
729
|
+
const float err = fabsf(ax - pos_lut[i]);
|
|
730
|
+
if (err < best_err) {
|
|
731
|
+
best_err = err;
|
|
732
|
+
best_i = i;
|
|
733
|
+
}
|
|
734
|
+
}
|
|
735
|
+
|
|
736
|
+
return static_cast<uint8_t>(best_i | sign_bit);
|
|
737
|
+
}
|
|
738
|
+
|
|
646
739
|
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
|
|
647
740
|
// Precompute mp (m' in the paper) and L such that division
|
|
648
741
|
// can be computed using a multiply (high 32b of 64b result)
|
|
649
742
|
// and a shift:
|
|
650
743
|
//
|
|
651
744
|
// n/d = (mulhi(n, mp) + n) >> L;
|
|
652
|
-
static const uint3 init_fastdiv_values(
|
|
653
|
-
GGML_ASSERT(
|
|
745
|
+
static const uint3 init_fastdiv_values(uint64_t d_64) {
|
|
746
|
+
GGML_ASSERT(d_64 != 0);
|
|
747
|
+
GGML_ASSERT(d_64 <= std::numeric_limits<uint32_t>::max());
|
|
748
|
+
|
|
749
|
+
uint32_t d = (uint32_t)d_64;
|
|
654
750
|
|
|
655
751
|
// compute L = ceil(log2(d));
|
|
656
752
|
uint32_t L = 0;
|
|
@@ -854,15 +950,16 @@ struct ggml_cuda_device_info {
|
|
|
854
950
|
int device_count;
|
|
855
951
|
|
|
856
952
|
struct cuda_device_info {
|
|
857
|
-
int cc;
|
|
858
|
-
int nsm;
|
|
859
|
-
size_t smpb;
|
|
860
|
-
size_t smpbo;
|
|
861
|
-
bool integrated;
|
|
862
|
-
bool vmm;
|
|
863
|
-
size_t vmm_granularity;
|
|
953
|
+
int cc; // compute capability
|
|
954
|
+
int nsm; // number of streaming multiprocessors
|
|
955
|
+
size_t smpb; // max. shared memory per block
|
|
956
|
+
size_t smpbo; // max. shared memory per block (with opt-in)
|
|
957
|
+
bool integrated; // Device is integrated as opposed to discrete
|
|
958
|
+
bool vmm; // virtual memory support
|
|
959
|
+
size_t vmm_granularity; // granularity of virtual memory
|
|
864
960
|
size_t total_vram;
|
|
865
|
-
int warp_size;
|
|
961
|
+
int warp_size; // Number of threads in a dispatch
|
|
962
|
+
bool supports_cooperative_launch; // whether cooperative launch is supported
|
|
866
963
|
};
|
|
867
964
|
|
|
868
965
|
cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
|
|
@@ -939,7 +1036,7 @@ struct ggml_tensor_extra_gpu {
|
|
|
939
1036
|
#define USE_CUDA_GRAPH
|
|
940
1037
|
#endif
|
|
941
1038
|
|
|
942
|
-
struct
|
|
1039
|
+
struct ggml_cuda_graph_node_properties {
|
|
943
1040
|
void * node_address;
|
|
944
1041
|
ggml_op node_op;
|
|
945
1042
|
int64_t ne[GGML_MAX_DIMS];
|
|
@@ -962,22 +1059,181 @@ struct ggml_cuda_graph {
|
|
|
962
1059
|
cudaGraphExec_t instance = nullptr;
|
|
963
1060
|
size_t num_nodes = 0;
|
|
964
1061
|
std::vector<cudaGraphNode_t> nodes;
|
|
965
|
-
std::vector<cudaKernelNodeParams> params;
|
|
966
1062
|
bool disable_due_to_gpu_arch = false;
|
|
967
1063
|
bool disable_due_to_too_many_updates = false;
|
|
968
|
-
bool disable_due_to_failed_graph_capture = false;
|
|
969
1064
|
int number_consecutive_updates = 0;
|
|
970
|
-
std::vector<
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
|
|
974
|
-
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
|
|
1065
|
+
std::vector<ggml_cuda_graph_node_properties> props;
|
|
1066
|
+
|
|
1067
|
+
void record_update(bool use_graph, bool update_required) {
|
|
1068
|
+
if (use_graph && update_required) {
|
|
1069
|
+
number_consecutive_updates++;
|
|
1070
|
+
} else {
|
|
1071
|
+
number_consecutive_updates = 0;
|
|
1072
|
+
}
|
|
1073
|
+
if (number_consecutive_updates >= 4) {
|
|
1074
|
+
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
|
|
1075
|
+
disable_due_to_too_many_updates = true;
|
|
1076
|
+
}
|
|
1077
|
+
}
|
|
1078
|
+
|
|
1079
|
+
bool is_enabled() const {
|
|
1080
|
+
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
|
|
1081
|
+
return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates);
|
|
1082
|
+
}
|
|
978
1083
|
#endif
|
|
979
1084
|
};
|
|
980
1085
|
|
|
1086
|
+
struct ggml_cuda_concurrent_event {
|
|
1087
|
+
std::vector<cudaEvent_t> join_events;
|
|
1088
|
+
cudaEvent_t fork_event = nullptr;
|
|
1089
|
+
|
|
1090
|
+
int n_streams = 0;
|
|
1091
|
+
std::unordered_map<const ggml_tensor *, int> stream_mapping;
|
|
1092
|
+
|
|
1093
|
+
// Original order of nodes in this concurrent region (before interleaving)
|
|
1094
|
+
// Used to restore grouping for fusion within streams
|
|
1095
|
+
std::vector<const ggml_tensor *> original_order;
|
|
1096
|
+
|
|
1097
|
+
const ggml_tensor * join_node;
|
|
1098
|
+
|
|
1099
|
+
ggml_cuda_concurrent_event() = default;
|
|
1100
|
+
|
|
1101
|
+
ggml_cuda_concurrent_event(const ggml_cuda_concurrent_event &) = delete;
|
|
1102
|
+
ggml_cuda_concurrent_event & operator=(const ggml_cuda_concurrent_event &) = delete;
|
|
1103
|
+
|
|
1104
|
+
explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {
|
|
1105
|
+
join_events.resize(n_streams);
|
|
1106
|
+
|
|
1107
|
+
for (size_t i = 0; i < join_events.size(); ++i) {
|
|
1108
|
+
CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming));
|
|
1109
|
+
}
|
|
1110
|
+
|
|
1111
|
+
CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));
|
|
1112
|
+
}
|
|
1113
|
+
|
|
1114
|
+
ggml_cuda_concurrent_event(ggml_cuda_concurrent_event && other) noexcept
|
|
1115
|
+
: join_events(std::move(other.join_events))
|
|
1116
|
+
, fork_event(other.fork_event)
|
|
1117
|
+
, n_streams(other.n_streams)
|
|
1118
|
+
, stream_mapping(std::move(other.stream_mapping))
|
|
1119
|
+
, original_order(std::move(other.original_order))
|
|
1120
|
+
, join_node(other.join_node) {
|
|
1121
|
+
other.fork_event = nullptr;
|
|
1122
|
+
}
|
|
1123
|
+
|
|
1124
|
+
// 1. check if any branches write to overlapping memory ranges (except the join node)
|
|
1125
|
+
// 2. check whether all srcs are either within the branch or outside the nodes covered by ggml_cuda_concurrent_event
|
|
1126
|
+
// we assume all nodes have the same buffer
|
|
1127
|
+
bool is_valid() const {
|
|
1128
|
+
std::vector<std::vector<std::pair<int64_t, int64_t>>> write_ranges;
|
|
1129
|
+
write_ranges.resize(n_streams);
|
|
1130
|
+
|
|
1131
|
+
// get join_node's memory range to exclude from overlap checking.
|
|
1132
|
+
// multiple nodes can use join_node's buffer; we synchronize on the join node.
|
|
1133
|
+
const ggml_tensor * join_t = join_node->view_src ? join_node->view_src : join_node;
|
|
1134
|
+
const int64_t join_start = (int64_t) join_t->data;
|
|
1135
|
+
const int64_t join_end = join_start + ggml_nbytes(join_t);
|
|
1136
|
+
|
|
1137
|
+
for (const auto & [tensor, stream] : stream_mapping) {
|
|
1138
|
+
const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
|
|
1139
|
+
const int64_t t_start = (int64_t) t->data;
|
|
1140
|
+
const int64_t t_end = t_start + ggml_nbytes(t);
|
|
1141
|
+
|
|
1142
|
+
// skip tensors that overlap with join_node's buffer.
|
|
1143
|
+
if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
|
|
1144
|
+
continue;
|
|
1145
|
+
}
|
|
1146
|
+
|
|
1147
|
+
// concurrent streams begin from 1
|
|
1148
|
+
write_ranges[stream - 1].emplace_back(t_start, t_end);
|
|
1149
|
+
}
|
|
1150
|
+
|
|
1151
|
+
for (int i = 0; i < n_streams; ++i) {
|
|
1152
|
+
// sorts first by start then by end of write range
|
|
1153
|
+
std::sort(write_ranges[i].begin(), write_ranges[i].end());
|
|
1154
|
+
}
|
|
1155
|
+
|
|
1156
|
+
bool writes_overlap = false;
|
|
1157
|
+
bool dependent_srcs = false;
|
|
1158
|
+
for (const auto & [tensor, stream] : stream_mapping) {
|
|
1159
|
+
const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
|
|
1160
|
+
const int64_t t_start = (int64_t) t->data;
|
|
1161
|
+
const int64_t t_end = t_start + ggml_nbytes(t);
|
|
1162
|
+
|
|
1163
|
+
// skip tensors that overlap with join_node's buffer
|
|
1164
|
+
if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
|
|
1165
|
+
continue;
|
|
1166
|
+
}
|
|
1167
|
+
|
|
1168
|
+
// check if this buffer's write data overlaps with another stream's
|
|
1169
|
+
std::pair<int64_t, int64_t> data_range = std::make_pair(t_start, t_end);
|
|
1170
|
+
for (int i = 0; i < n_streams; ++i) {
|
|
1171
|
+
if (i == stream - 1) {
|
|
1172
|
+
continue;
|
|
1173
|
+
}
|
|
1174
|
+
auto it = std::lower_bound(write_ranges[i].begin(), write_ranges[i].end(), data_range);
|
|
1175
|
+
|
|
1176
|
+
if (it != write_ranges[i].end()) {
|
|
1177
|
+
const std::pair<int64_t, int64_t> & other = *it;
|
|
1178
|
+
|
|
1179
|
+
// std::lower_bound returns the first element where other >= data_range (lexicographically).
|
|
1180
|
+
// This guarantees other.first >= data_range.first.
|
|
1181
|
+
// Therefore, overlap occurs iff other.first < data_range.second
|
|
1182
|
+
// (i.e., the other range starts before this range ends).
|
|
1183
|
+
if (other.first < data_range.second) {
|
|
1184
|
+
GGML_LOG_DEBUG("Writes overlap for %s", tensor->name);
|
|
1185
|
+
writes_overlap = true;
|
|
1186
|
+
break;
|
|
1187
|
+
}
|
|
1188
|
+
}
|
|
1189
|
+
}
|
|
1190
|
+
|
|
1191
|
+
//check if all srcs are either in branch or don't have a branch
|
|
1192
|
+
for (int i = 0; i < GGML_MAX_SRC; ++i) {
|
|
1193
|
+
if (!tensor->src[i]) {
|
|
1194
|
+
continue;
|
|
1195
|
+
}
|
|
1196
|
+
|
|
1197
|
+
auto it = stream_mapping.find(tensor->src[i]);
|
|
1198
|
+
|
|
1199
|
+
if (it == stream_mapping.end()) {
|
|
1200
|
+
continue;
|
|
1201
|
+
}
|
|
1202
|
+
|
|
1203
|
+
if (it->second != stream) {
|
|
1204
|
+
dependent_srcs = true;
|
|
1205
|
+
break;
|
|
1206
|
+
}
|
|
1207
|
+
}
|
|
1208
|
+
|
|
1209
|
+
if (dependent_srcs || writes_overlap) {
|
|
1210
|
+
break;
|
|
1211
|
+
}
|
|
1212
|
+
}
|
|
1213
|
+
|
|
1214
|
+
return !writes_overlap && !dependent_srcs;
|
|
1215
|
+
}
|
|
1216
|
+
|
|
1217
|
+
~ggml_cuda_concurrent_event() {
|
|
1218
|
+
if (fork_event != nullptr) {
|
|
1219
|
+
CUDA_CHECK(cudaEventDestroy(fork_event));
|
|
1220
|
+
}
|
|
1221
|
+
for (cudaEvent_t e : join_events) {
|
|
1222
|
+
if (e != nullptr) {
|
|
1223
|
+
CUDA_CHECK(cudaEventDestroy(e));
|
|
1224
|
+
}
|
|
1225
|
+
}
|
|
1226
|
+
}
|
|
1227
|
+
};
|
|
1228
|
+
|
|
1229
|
+
struct ggml_cuda_stream_context {
|
|
1230
|
+
std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;
|
|
1231
|
+
|
|
1232
|
+
void reset() {
|
|
1233
|
+
concurrent_events.clear();
|
|
1234
|
+
}
|
|
1235
|
+
};
|
|
1236
|
+
|
|
981
1237
|
struct ggml_backend_cuda_context {
|
|
982
1238
|
int device;
|
|
983
1239
|
std::string name;
|
|
@@ -988,11 +1244,15 @@ struct ggml_backend_cuda_context {
|
|
|
988
1244
|
|
|
989
1245
|
std::unique_ptr<ggml_cuda_graph> cuda_graph;
|
|
990
1246
|
|
|
1247
|
+
int curr_stream_no = 0;
|
|
1248
|
+
|
|
991
1249
|
explicit ggml_backend_cuda_context(int device) :
|
|
992
1250
|
device(device),
|
|
993
1251
|
name(GGML_CUDA_NAME + std::to_string(device)) {
|
|
994
1252
|
}
|
|
995
1253
|
|
|
1254
|
+
ggml_cuda_stream_context concurrent_stream_context;
|
|
1255
|
+
|
|
996
1256
|
~ggml_backend_cuda_context();
|
|
997
1257
|
|
|
998
1258
|
cudaStream_t stream(int device, int stream) {
|
|
@@ -1003,9 +1263,9 @@ struct ggml_backend_cuda_context {
|
|
|
1003
1263
|
return streams[device][stream];
|
|
1004
1264
|
}
|
|
1005
1265
|
|
|
1006
|
-
cudaStream_t stream() {
|
|
1007
|
-
|
|
1008
|
-
}
|
|
1266
|
+
cudaStream_t stream() { return stream(device, curr_stream_no); }
|
|
1267
|
+
|
|
1268
|
+
ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; }
|
|
1009
1269
|
|
|
1010
1270
|
cublasHandle_t cublas_handle(int device) {
|
|
1011
1271
|
if (cublas_handles[device] == nullptr) {
|
|
@@ -1021,18 +1281,31 @@ struct ggml_backend_cuda_context {
|
|
|
1021
1281
|
}
|
|
1022
1282
|
|
|
1023
1283
|
// pool
|
|
1024
|
-
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
|
|
1284
|
+
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
|
|
1025
1285
|
|
|
1026
|
-
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);
|
|
1286
|
+
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);
|
|
1027
1287
|
|
|
1028
1288
|
ggml_cuda_pool & pool(int device) {
|
|
1029
|
-
if (pools[device] == nullptr) {
|
|
1030
|
-
pools[device] = new_pool_for_device(device);
|
|
1289
|
+
if (pools[device][curr_stream_no] == nullptr) {
|
|
1290
|
+
pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no);
|
|
1031
1291
|
}
|
|
1032
|
-
return *pools[device];
|
|
1292
|
+
return *pools[device][curr_stream_no];
|
|
1033
1293
|
}
|
|
1034
1294
|
|
|
1035
1295
|
ggml_cuda_pool & pool() {
|
|
1036
1296
|
return pool(device);
|
|
1037
1297
|
}
|
|
1038
1298
|
};
|
|
1299
|
+
|
|
1300
|
+
struct ggml_cuda_mm_fusion_args_host {
|
|
1301
|
+
const ggml_tensor * x_bias = nullptr;
|
|
1302
|
+
const ggml_tensor * gate = nullptr;
|
|
1303
|
+
const ggml_tensor * gate_bias = nullptr;
|
|
1304
|
+
ggml_glu_op glu_op;
|
|
1305
|
+
};
|
|
1306
|
+
struct ggml_cuda_mm_fusion_args_device {
|
|
1307
|
+
const void * x_bias = nullptr;
|
|
1308
|
+
const void * gate = nullptr;
|
|
1309
|
+
const void * gate_bias = nullptr;
|
|
1310
|
+
ggml_glu_op glu_op;
|
|
1311
|
+
};
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
#pragma once
|
|
1
2
|
#include "common.cuh"
|
|
2
3
|
|
|
3
4
|
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
|
@@ -38,6 +39,15 @@ template<typename dst_t, typename src_t>
|
|
|
38
39
|
return __float2bfloat16(float(x));
|
|
39
40
|
} else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
|
|
40
41
|
return __bfloat162float(x);
|
|
42
|
+
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
|
|
43
|
+
return __float22half2_rn(x);
|
|
44
|
+
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
|
|
45
|
+
// bypass compile error on cuda 12.0.1
|
|
46
|
+
#ifdef GGML_USE_HIP
|
|
47
|
+
return __float22bfloat162_rn(x);
|
|
48
|
+
#else
|
|
49
|
+
return {x.x, x.y};
|
|
50
|
+
#endif // GGML_USE_HIP
|
|
41
51
|
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
|
|
42
52
|
return int32_t(x);
|
|
43
53
|
} else {
|
|
@@ -212,6 +212,6 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
|
|
|
212
212
|
}
|
|
213
213
|
|
|
214
214
|
template<typename src_t, typename dst_t>
|
|
215
|
-
static __device__ void
|
|
215
|
+
static __device__ void cpy_1_scalar(const char * cxi, char * cdsti) {
|
|
216
216
|
*(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi);
|
|
217
217
|
}
|