whispercpp 1.3.4 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +158 -44
- data/ext/extconf.rb +3 -2
- data/ext/ruby_whisper.c +34 -6
- data/ext/ruby_whisper.h +67 -0
- data/ext/ruby_whisper_context.c +236 -144
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +12 -13
- data/ext/ruby_whisper_params.c +47 -24
- data/ext/ruby_whisper_segment.c +84 -20
- data/ext/ruby_whisper_token.c +371 -0
- data/ext/ruby_whisper_transcribe.cpp +5 -2
- data/ext/ruby_whisper_vad_context.c +122 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +138 -0
- data/ext/ruby_whisper_vad_segments.c +105 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +129 -112
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +28 -15
- data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
- data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
- data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
- data/ext/sources/examples/talk-llama/llama-context.h +70 -23
- data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
- data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
- data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
- data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
- data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
- data/ext/sources/examples/talk-llama/llama-model.h +112 -18
- data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
- data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
- data/ext/sources/examples/talk-llama/llama.cpp +802 -21
- data/ext/sources/examples/talk-llama/llama.h +210 -39
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +704 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +90 -56
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +5 -2
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +6 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +14 -12
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +246 -21
- data/ext/sources/ggml/src/CMakeLists.txt +85 -11
- data/ext/sources/ggml/src/ggml-alloc.c +128 -50
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
- data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
- data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
- data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
- data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
- data/ext/sources/ggml/src/ggml-impl.h +129 -6
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
- data/ext/sources/ggml/src/ggml.c +590 -64
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +106 -62
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +162 -4
- data/test/test_context_params.rb +82 -0
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +81 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +100 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +27 -0
- data/whispercpp.gemspec +1 -1
- metadata +502 -37
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -7,19 +7,21 @@
|
|
|
7
7
|
|
|
8
8
|
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
|
|
9
9
|
|
|
10
|
+
const int CUDA_CPY_TILE_DIM_2D = 32; // 2D tile dimension for transposed blocks
|
|
11
|
+
const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available
|
|
12
|
+
const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows
|
|
13
|
+
|
|
10
14
|
template <cpy_kernel_t cpy_1>
|
|
11
|
-
static __global__ void
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
15
|
+
static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne,
|
|
16
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
|
|
17
|
+
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
|
|
18
|
+
const int64_t nb12, const int64_t nb13) {
|
|
19
|
+
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
|
|
16
20
|
|
|
17
21
|
if (i >= ne) {
|
|
18
22
|
return;
|
|
19
23
|
}
|
|
20
24
|
|
|
21
|
-
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
|
|
22
|
-
|
|
23
25
|
// determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
|
|
24
26
|
// then combine those indices with the corresponding byte offsets to get the total offsets
|
|
25
27
|
const int64_t i03 = i/(ne00 * ne01 * ne02);
|
|
@@ -37,6 +39,61 @@ static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne
|
|
|
37
39
|
cpy_1(cx + x_offset, cdst + dst_offset);
|
|
38
40
|
}
|
|
39
41
|
|
|
42
|
+
template <typename T>
|
|
43
|
+
static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int64_t ne,
|
|
44
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
|
|
45
|
+
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
|
|
46
|
+
const int64_t nb12, const int64_t nb13) {
|
|
47
|
+
|
|
48
|
+
const T* src = reinterpret_cast<const T*>(cx);
|
|
49
|
+
T* dst = reinterpret_cast<T*>(cdst);
|
|
50
|
+
|
|
51
|
+
const int64_t nmat = ne / (ne00 * ne01);
|
|
52
|
+
const int64_t n = ne00 * ne01;
|
|
53
|
+
|
|
54
|
+
const int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
|
|
55
|
+
const int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
|
|
56
|
+
const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
|
|
57
|
+
const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
|
|
58
|
+
|
|
59
|
+
__shared__ float tile[2][CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];
|
|
60
|
+
int cur_tile_buf = 0;
|
|
61
|
+
|
|
62
|
+
#pragma unroll
|
|
63
|
+
for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {
|
|
64
|
+
|
|
65
|
+
const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
|
|
66
|
+
if (imat >= nmat)
|
|
67
|
+
break;
|
|
68
|
+
|
|
69
|
+
#pragma unroll
|
|
70
|
+
for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
|
|
71
|
+
if(x < ne01 && y + j < ne00){
|
|
72
|
+
const int row = threadIdx.y+j;
|
|
73
|
+
const int col = threadIdx.x * sizeof(float)/sizeof(T);
|
|
74
|
+
T *tile2 = reinterpret_cast<T*>(tile[cur_tile_buf][row]);
|
|
75
|
+
tile2[col] = src[imat*n + (y+j)*ne01 + x];
|
|
76
|
+
}
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
__syncthreads();
|
|
80
|
+
|
|
81
|
+
#pragma unroll
|
|
82
|
+
for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
|
|
83
|
+
if (ty + j < ne01 && tx < ne00) {
|
|
84
|
+
const int col = (threadIdx.y+j)*sizeof(float)/sizeof(T);
|
|
85
|
+
const T *tile2 = reinterpret_cast<const T*>(tile[cur_tile_buf][threadIdx.x]);
|
|
86
|
+
dst[imat*n + (ty+j)*ne00 + tx] = tile2[col];
|
|
87
|
+
}
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
cur_tile_buf = (cur_tile_buf + 1) % 2;
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11,
|
|
94
|
+
nb12, nb13);
|
|
95
|
+
}
|
|
96
|
+
|
|
40
97
|
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
|
|
41
98
|
float * cdstf = (float *)(cdsti);
|
|
42
99
|
|
|
@@ -63,227 +120,262 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
|
|
|
63
120
|
}
|
|
64
121
|
|
|
65
122
|
template <cpy_kernel_t cpy_blck, int qk>
|
|
66
|
-
static __global__ void cpy_f32_q(const char * cx, char *
|
|
67
|
-
const
|
|
68
|
-
const
|
|
69
|
-
const
|
|
70
|
-
const
|
|
123
|
+
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne,
|
|
124
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
|
|
125
|
+
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
|
|
126
|
+
const int64_t nb12, const int64_t nb13) {
|
|
127
|
+
const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk;
|
|
71
128
|
|
|
72
129
|
if (i >= ne) {
|
|
73
130
|
return;
|
|
74
131
|
}
|
|
75
132
|
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
const
|
|
79
|
-
const
|
|
80
|
-
const
|
|
81
|
-
const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
|
|
82
|
-
const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
|
|
133
|
+
const int64_t i03 = i/(ne00 * ne01 * ne02);
|
|
134
|
+
const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
|
|
135
|
+
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
|
|
136
|
+
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
|
|
137
|
+
const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
|
|
83
138
|
|
|
84
|
-
const
|
|
85
|
-
const
|
|
86
|
-
const
|
|
87
|
-
const
|
|
88
|
-
const
|
|
139
|
+
const int64_t i13 = i/(ne10 * ne11 * ne12);
|
|
140
|
+
const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
|
|
141
|
+
const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
|
|
142
|
+
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
|
|
143
|
+
const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
|
|
89
144
|
|
|
90
145
|
cpy_blck(cx + x_offset, cdst + dst_offset);
|
|
91
146
|
}
|
|
92
147
|
|
|
93
148
|
template <cpy_kernel_t cpy_blck, int qk>
|
|
94
|
-
static __global__ void cpy_q_f32(const char * cx, char *
|
|
95
|
-
const
|
|
96
|
-
const
|
|
97
|
-
const
|
|
98
|
-
const
|
|
149
|
+
static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne,
|
|
150
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
|
|
151
|
+
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
|
|
152
|
+
const int64_t nb12, const int64_t nb13) {
|
|
153
|
+
const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk;
|
|
99
154
|
|
|
100
155
|
if (i >= ne) {
|
|
101
156
|
return;
|
|
102
157
|
}
|
|
103
158
|
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
const
|
|
107
|
-
const
|
|
108
|
-
const
|
|
109
|
-
const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
|
|
110
|
-
const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
|
|
159
|
+
const int64_t i03 = i/(ne00 * ne01 * ne02);
|
|
160
|
+
const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
|
|
161
|
+
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
|
|
162
|
+
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
|
|
163
|
+
const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
|
|
111
164
|
|
|
112
|
-
const
|
|
113
|
-
const
|
|
114
|
-
const
|
|
115
|
-
const
|
|
116
|
-
const
|
|
165
|
+
const int64_t i13 = i/(ne10 * ne11 * ne12);
|
|
166
|
+
const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
|
|
167
|
+
const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
|
|
168
|
+
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
|
|
169
|
+
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
|
|
117
170
|
|
|
118
171
|
cpy_blck(cx + x_offset, cdst + dst_offset);
|
|
119
172
|
}
|
|
120
173
|
|
|
121
|
-
|
|
174
|
+
template<typename src_t, typename dst_t>
|
|
175
|
+
static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) {
|
|
176
|
+
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
|
|
122
177
|
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
|
|
126
|
-
CUDA_CHECK(cudaStreamSynchronize(stream));
|
|
127
|
-
if (cuda_graph->dest_ptrs_d != nullptr) {
|
|
128
|
-
CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d));
|
|
129
|
-
}
|
|
130
|
-
CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *)));
|
|
131
|
-
cuda_graph->dest_ptrs_size = host_dest_ptrs_size;
|
|
178
|
+
if (i >= ne) {
|
|
179
|
+
return;
|
|
132
180
|
}
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
#endif
|
|
181
|
+
|
|
182
|
+
const src_t * x = (const src_t *) cx;
|
|
183
|
+
dst_t * dst = (dst_t *) cdst;
|
|
184
|
+
|
|
185
|
+
dst[i] = ggml_cuda_cast<dst_t>(x[i]);
|
|
139
186
|
}
|
|
140
187
|
|
|
141
188
|
template<typename src_t, typename dst_t>
|
|
142
|
-
static void
|
|
143
|
-
const char * cx, char * cdst, const
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
(cx, cdst, ne
|
|
189
|
+
static void ggml_cpy_scalar_contiguous_cuda(
|
|
190
|
+
const char * cx, char * cdst, const int64_t ne,
|
|
191
|
+
cudaStream_t stream) {
|
|
192
|
+
|
|
193
|
+
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
|
194
|
+
GGML_ASSERT(num_blocks < UINT_MAX);
|
|
195
|
+
cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
|
196
|
+
(cx, cdst, ne);
|
|
197
|
+
}
|
|
198
|
+
|
|
199
|
+
template<typename src_t, typename dst_t, bool transposed = false>
|
|
200
|
+
static void ggml_cpy_scalar_cuda(
|
|
201
|
+
const char * cx, char * cdst, const int64_t ne,
|
|
202
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
|
|
203
|
+
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
|
|
204
|
+
|
|
205
|
+
if (transposed) {
|
|
206
|
+
GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
|
|
207
|
+
int64_t ne00n, ne01n, ne02n;
|
|
208
|
+
if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here
|
|
209
|
+
ne00n = ne00;
|
|
210
|
+
ne01n = ne01;
|
|
211
|
+
ne02n = ne02;
|
|
212
|
+
} else {
|
|
213
|
+
ne00n = ne00;
|
|
214
|
+
ne01n = ne01*ne02;
|
|
215
|
+
ne02n = 1;
|
|
216
|
+
}
|
|
217
|
+
|
|
218
|
+
int64_t grid_x = (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
|
|
219
|
+
int64_t grid_y = (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
|
|
220
|
+
int64_t grid_z = (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM;
|
|
221
|
+
GGML_ASSERT(grid_x < UINT_MAX);
|
|
222
|
+
GGML_ASSERT(grid_y < USHRT_MAX);
|
|
223
|
+
GGML_ASSERT(grid_z < USHRT_MAX);
|
|
224
|
+
dim3 dimGrid(grid_x, grid_y, grid_z);
|
|
225
|
+
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
|
|
226
|
+
cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
|
|
227
|
+
(cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
|
228
|
+
} else {
|
|
229
|
+
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
|
230
|
+
GGML_ASSERT(num_blocks < UINT_MAX);
|
|
231
|
+
cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
|
232
|
+
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
|
233
|
+
}
|
|
150
234
|
}
|
|
151
235
|
|
|
152
236
|
static void ggml_cpy_f32_q8_0_cuda(
|
|
153
|
-
const char * cx, char * cdst, const
|
|
154
|
-
const
|
|
155
|
-
const
|
|
237
|
+
const char * cx, char * cdst, const int64_t ne,
|
|
238
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
|
|
239
|
+
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
|
|
156
240
|
|
|
157
241
|
GGML_ASSERT(ne % QK8_0 == 0);
|
|
158
|
-
const
|
|
242
|
+
const int64_t num_blocks = ne / QK8_0;
|
|
243
|
+
GGML_ASSERT(num_blocks < UINT_MAX);
|
|
159
244
|
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
|
|
160
|
-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13
|
|
245
|
+
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
|
161
246
|
}
|
|
162
247
|
|
|
163
248
|
static void ggml_cpy_q8_0_f32_cuda(
|
|
164
|
-
const char * cx, char * cdst, const
|
|
165
|
-
const
|
|
166
|
-
const
|
|
249
|
+
const char * cx, char * cdst, const int64_t ne,
|
|
250
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
|
|
251
|
+
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
|
|
167
252
|
|
|
168
|
-
const
|
|
253
|
+
const int64_t num_blocks = ne;
|
|
254
|
+
GGML_ASSERT(num_blocks < UINT_MAX);
|
|
169
255
|
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
|
|
170
|
-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13
|
|
256
|
+
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
|
171
257
|
}
|
|
172
258
|
|
|
173
259
|
static void ggml_cpy_f32_q4_0_cuda(
|
|
174
|
-
const char * cx, char * cdst, const
|
|
175
|
-
const
|
|
176
|
-
const
|
|
260
|
+
const char * cx, char * cdst, const int64_t ne,
|
|
261
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
|
|
262
|
+
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
|
|
177
263
|
|
|
178
264
|
GGML_ASSERT(ne % QK4_0 == 0);
|
|
179
|
-
const
|
|
265
|
+
const int64_t num_blocks = ne / QK4_0;
|
|
266
|
+
GGML_ASSERT(num_blocks < UINT_MAX);
|
|
180
267
|
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
|
|
181
|
-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13
|
|
268
|
+
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
|
182
269
|
}
|
|
183
270
|
|
|
184
271
|
static void ggml_cpy_q4_0_f32_cuda(
|
|
185
|
-
const char * cx, char * cdst, const
|
|
186
|
-
const
|
|
187
|
-
const
|
|
188
|
-
const
|
|
189
|
-
const
|
|
190
|
-
cudaStream_t stream
|
|
191
|
-
const
|
|
272
|
+
const char * cx, char * cdst, const int64_t ne,
|
|
273
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
|
274
|
+
const int64_t nb00, const int64_t nb01, const int64_t nb02,
|
|
275
|
+
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
|
|
276
|
+
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
|
|
277
|
+
cudaStream_t stream) {
|
|
278
|
+
const int64_t num_blocks = ne;
|
|
279
|
+
GGML_ASSERT(num_blocks < UINT_MAX);
|
|
192
280
|
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
|
|
193
281
|
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
|
194
|
-
ne10, ne11, ne12, nb10, nb11, nb12, nb13
|
|
282
|
+
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
|
195
283
|
}
|
|
196
284
|
|
|
197
285
|
static void ggml_cpy_f32_q4_1_cuda(
|
|
198
|
-
const char * cx, char * cdst, const
|
|
199
|
-
const
|
|
200
|
-
const
|
|
286
|
+
const char * cx, char * cdst, const int64_t ne,
|
|
287
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
|
|
288
|
+
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
|
|
201
289
|
|
|
202
290
|
GGML_ASSERT(ne % QK4_1 == 0);
|
|
203
|
-
const
|
|
291
|
+
const int64_t num_blocks = ne / QK4_1;
|
|
292
|
+
GGML_ASSERT(num_blocks < UINT_MAX);
|
|
204
293
|
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
|
|
205
|
-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13
|
|
294
|
+
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
|
206
295
|
}
|
|
207
296
|
|
|
208
297
|
static void ggml_cpy_q4_1_f32_cuda(
|
|
209
|
-
const char * cx, char * cdst, const
|
|
210
|
-
const
|
|
211
|
-
const
|
|
212
|
-
const
|
|
213
|
-
const
|
|
214
|
-
cudaStream_t stream
|
|
215
|
-
const
|
|
298
|
+
const char * cx, char * cdst, const int64_t ne,
|
|
299
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
|
300
|
+
const int64_t nb00, const int64_t nb01, const int64_t nb02,
|
|
301
|
+
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
|
|
302
|
+
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
|
|
303
|
+
cudaStream_t stream) {
|
|
304
|
+
const int64_t num_blocks = ne;
|
|
305
|
+
GGML_ASSERT(num_blocks < UINT_MAX);
|
|
216
306
|
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
|
|
217
307
|
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
|
218
|
-
ne10, ne11, ne12, nb10, nb11, nb12, nb13
|
|
308
|
+
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
|
219
309
|
}
|
|
220
310
|
|
|
221
311
|
static void ggml_cpy_f32_q5_0_cuda(
|
|
222
|
-
const char * cx, char * cdst, const
|
|
223
|
-
const
|
|
224
|
-
const
|
|
312
|
+
const char * cx, char * cdst, const int64_t ne,
|
|
313
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
|
|
314
|
+
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
|
|
225
315
|
|
|
226
316
|
GGML_ASSERT(ne % QK5_0 == 0);
|
|
227
|
-
const
|
|
317
|
+
const int64_t num_blocks = ne / QK5_0;
|
|
318
|
+
GGML_ASSERT(num_blocks < UINT_MAX);
|
|
228
319
|
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
|
|
229
|
-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13
|
|
320
|
+
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
|
230
321
|
}
|
|
231
322
|
|
|
232
323
|
static void ggml_cpy_q5_0_f32_cuda(
|
|
233
|
-
const char * cx, char * cdst, const
|
|
234
|
-
const
|
|
235
|
-
const
|
|
236
|
-
const
|
|
237
|
-
const
|
|
238
|
-
cudaStream_t stream
|
|
239
|
-
const
|
|
324
|
+
const char * cx, char * cdst, const int64_t ne,
|
|
325
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
|
326
|
+
const int64_t nb00, const int64_t nb01, const int64_t nb02,
|
|
327
|
+
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
|
|
328
|
+
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
|
|
329
|
+
cudaStream_t stream) {
|
|
330
|
+
const int64_t num_blocks = ne;
|
|
331
|
+
GGML_ASSERT(num_blocks < UINT_MAX);
|
|
240
332
|
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
|
|
241
333
|
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
|
242
|
-
ne10, ne11, ne12, nb10, nb11, nb12, nb13
|
|
334
|
+
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
|
243
335
|
}
|
|
244
336
|
|
|
245
337
|
static void ggml_cpy_f32_q5_1_cuda(
|
|
246
|
-
const char * cx, char * cdst, const
|
|
247
|
-
const
|
|
248
|
-
const
|
|
338
|
+
const char * cx, char * cdst, const int64_t ne,
|
|
339
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
|
|
340
|
+
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
|
|
249
341
|
|
|
250
342
|
GGML_ASSERT(ne % QK5_1 == 0);
|
|
251
|
-
const
|
|
343
|
+
const int64_t num_blocks = ne / QK5_1;
|
|
344
|
+
GGML_ASSERT(num_blocks < UINT_MAX);
|
|
252
345
|
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
|
|
253
|
-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13
|
|
346
|
+
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
|
254
347
|
}
|
|
255
348
|
|
|
256
349
|
static void ggml_cpy_q5_1_f32_cuda(
|
|
257
|
-
const char * cx, char * cdst, const
|
|
258
|
-
const
|
|
259
|
-
const
|
|
260
|
-
const
|
|
261
|
-
const
|
|
262
|
-
cudaStream_t stream
|
|
263
|
-
const
|
|
350
|
+
const char * cx, char * cdst, const int64_t ne,
|
|
351
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
|
352
|
+
const int64_t nb00, const int64_t nb01, const int64_t nb02,
|
|
353
|
+
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
|
|
354
|
+
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
|
|
355
|
+
cudaStream_t stream) {
|
|
356
|
+
const int64_t num_blocks = ne;
|
|
357
|
+
GGML_ASSERT(num_blocks < UINT_MAX);
|
|
264
358
|
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
|
|
265
359
|
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
|
266
|
-
ne10, ne11, ne12, nb10, nb11, nb12, nb13
|
|
360
|
+
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
|
267
361
|
}
|
|
268
362
|
|
|
269
363
|
static void ggml_cpy_f32_iq4_nl_cuda(
|
|
270
|
-
const char * cx, char * cdst, const
|
|
271
|
-
const
|
|
272
|
-
const
|
|
364
|
+
const char * cx, char * cdst, const int64_t ne,
|
|
365
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
|
|
366
|
+
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
|
|
273
367
|
|
|
274
368
|
GGML_ASSERT(ne % QK4_NL == 0);
|
|
275
|
-
const
|
|
369
|
+
const int64_t num_blocks = ne / QK4_NL;
|
|
370
|
+
GGML_ASSERT(num_blocks < UINT_MAX);
|
|
276
371
|
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
|
|
277
|
-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13
|
|
372
|
+
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
|
278
373
|
}
|
|
279
374
|
|
|
280
|
-
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1
|
|
375
|
+
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
|
|
281
376
|
const int64_t ne = ggml_nelements(src0);
|
|
282
377
|
GGML_ASSERT(ne == ggml_nelements(src1));
|
|
283
378
|
|
|
284
|
-
GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
|
|
285
|
-
GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
|
|
286
|
-
|
|
287
379
|
const int64_t ne00 = src0->ne[0];
|
|
288
380
|
const int64_t ne01 = src0->ne[1];
|
|
289
381
|
const int64_t ne02 = src0->ne[2];
|
|
@@ -311,17 +403,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
|
|
311
403
|
char * src0_ddc = (char *) src0->data;
|
|
312
404
|
char * src1_ddc = (char *) src1->data;
|
|
313
405
|
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
|
|
320
|
-
}
|
|
321
|
-
#else
|
|
322
|
-
GGML_UNUSED(disable_indirection_for_this_node);
|
|
323
|
-
#endif
|
|
324
|
-
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
|
406
|
+
const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
|
|
407
|
+
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&
|
|
408
|
+
src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);
|
|
409
|
+
|
|
410
|
+
if (src0->type == src1->type && contiguous_srcs) {
|
|
325
411
|
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
|
|
326
412
|
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
|
|
327
413
|
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
|
|
@@ -329,134 +415,144 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
|
|
329
415
|
} else
|
|
330
416
|
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
|
|
331
417
|
{
|
|
332
|
-
|
|
333
|
-
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
|
334
|
-
} else {
|
|
335
|
-
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
|
|
336
|
-
}
|
|
418
|
+
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
|
|
337
419
|
}
|
|
338
420
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
|
339
|
-
|
|
421
|
+
if (can_be_transposed) {
|
|
422
|
+
ggml_cpy_scalar_cuda<float, float, true>
|
|
423
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
424
|
+
} else {
|
|
425
|
+
ggml_cpy_scalar_cuda<float, float>
|
|
426
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
427
|
+
}
|
|
340
428
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
|
341
|
-
|
|
429
|
+
if (contiguous_srcs) {
|
|
430
|
+
ggml_cpy_scalar_contiguous_cuda<float, nv_bfloat16>
|
|
431
|
+
(src0_ddc, src1_ddc, ne, main_stream);
|
|
432
|
+
} else {
|
|
433
|
+
ggml_cpy_scalar_cuda<float, nv_bfloat16>
|
|
434
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
435
|
+
}
|
|
342
436
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
|
343
|
-
|
|
437
|
+
if (contiguous_srcs) {
|
|
438
|
+
ggml_cpy_scalar_contiguous_cuda<float, half>
|
|
439
|
+
(src0_ddc, src1_ddc, ne, main_stream);
|
|
440
|
+
} else {
|
|
441
|
+
ggml_cpy_scalar_cuda<float, half>
|
|
442
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
443
|
+
}
|
|
344
444
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
|
345
|
-
ggml_cpy_f32_q8_0_cuda
|
|
445
|
+
ggml_cpy_f32_q8_0_cuda
|
|
446
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
346
447
|
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
|
|
347
|
-
ggml_cpy_q8_0_f32_cuda
|
|
448
|
+
ggml_cpy_q8_0_f32_cuda
|
|
449
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
348
450
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
|
349
|
-
ggml_cpy_f32_q4_0_cuda
|
|
451
|
+
ggml_cpy_f32_q4_0_cuda
|
|
452
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
350
453
|
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
|
|
351
|
-
ggml_cpy_q4_0_f32_cuda
|
|
352
|
-
|
|
454
|
+
ggml_cpy_q4_0_f32_cuda
|
|
455
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
353
456
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
|
354
|
-
ggml_cpy_f32_q4_1_cuda
|
|
457
|
+
ggml_cpy_f32_q4_1_cuda
|
|
458
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
355
459
|
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
|
|
356
|
-
ggml_cpy_q4_1_f32_cuda
|
|
357
|
-
|
|
460
|
+
ggml_cpy_q4_1_f32_cuda
|
|
461
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
358
462
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
|
359
|
-
ggml_cpy_f32_q5_0_cuda
|
|
463
|
+
ggml_cpy_f32_q5_0_cuda
|
|
464
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
360
465
|
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
|
|
361
|
-
ggml_cpy_q5_0_f32_cuda
|
|
362
|
-
|
|
466
|
+
ggml_cpy_q5_0_f32_cuda
|
|
467
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
363
468
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
|
364
|
-
ggml_cpy_f32_iq4_nl_cuda
|
|
469
|
+
ggml_cpy_f32_iq4_nl_cuda
|
|
470
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
365
471
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
|
366
|
-
ggml_cpy_f32_q5_1_cuda
|
|
472
|
+
ggml_cpy_f32_q5_1_cuda
|
|
473
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
367
474
|
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
|
368
|
-
ggml_cpy_q5_1_f32_cuda
|
|
475
|
+
ggml_cpy_q5_1_f32_cuda
|
|
476
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
369
477
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
|
370
|
-
|
|
478
|
+
if (can_be_transposed) {
|
|
479
|
+
ggml_cpy_scalar_cuda<half, half, true>
|
|
480
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
481
|
+
} else {
|
|
482
|
+
ggml_cpy_scalar_cuda<half, half>
|
|
483
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
484
|
+
}
|
|
371
485
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
|
|
372
|
-
|
|
486
|
+
if (contiguous_srcs) {
|
|
487
|
+
ggml_cpy_scalar_contiguous_cuda<half, nv_bfloat16>
|
|
488
|
+
(src0_ddc, src1_ddc, ne, main_stream);
|
|
489
|
+
} else {
|
|
490
|
+
ggml_cpy_scalar_cuda<half, nv_bfloat16>
|
|
491
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
492
|
+
}
|
|
373
493
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
|
374
|
-
|
|
494
|
+
if (contiguous_srcs) {
|
|
495
|
+
ggml_cpy_scalar_contiguous_cuda<half, float>
|
|
496
|
+
(src0_ddc, src1_ddc, ne, main_stream);
|
|
497
|
+
} else {
|
|
498
|
+
ggml_cpy_scalar_cuda<half, float>
|
|
499
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
500
|
+
}
|
|
375
501
|
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
|
|
376
|
-
|
|
502
|
+
if (can_be_transposed) {
|
|
503
|
+
ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16, true>
|
|
504
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
505
|
+
} else {
|
|
506
|
+
ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16>
|
|
507
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
508
|
+
}
|
|
377
509
|
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
|
|
378
|
-
|
|
510
|
+
if (contiguous_srcs) {
|
|
511
|
+
ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, half>
|
|
512
|
+
(src0_ddc, src1_ddc, ne, main_stream);
|
|
513
|
+
} else {
|
|
514
|
+
ggml_cpy_scalar_cuda<nv_bfloat16, half>
|
|
515
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
516
|
+
}
|
|
379
517
|
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
|
380
|
-
|
|
518
|
+
if (contiguous_srcs) {
|
|
519
|
+
ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, float>
|
|
520
|
+
(src0_ddc, src1_ddc, ne, main_stream);
|
|
521
|
+
} else {
|
|
522
|
+
ggml_cpy_scalar_cuda<nv_bfloat16, float>
|
|
523
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
524
|
+
}
|
|
525
|
+
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
|
|
526
|
+
if (can_be_transposed) {
|
|
527
|
+
ggml_cpy_scalar_cuda<int32_t, int32_t, true>
|
|
528
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
529
|
+
} else {
|
|
530
|
+
ggml_cpy_scalar_cuda<int32_t, int32_t>
|
|
531
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
532
|
+
}
|
|
381
533
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
|
|
382
|
-
|
|
534
|
+
if (contiguous_srcs) {
|
|
535
|
+
ggml_cpy_scalar_contiguous_cuda<float, int32_t>
|
|
536
|
+
(src0_ddc, src1_ddc, ne, main_stream);
|
|
537
|
+
} else {
|
|
538
|
+
ggml_cpy_scalar_cuda<float, int32_t>
|
|
539
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
540
|
+
}
|
|
383
541
|
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
|
|
384
|
-
|
|
542
|
+
if (contiguous_srcs) {
|
|
543
|
+
ggml_cpy_scalar_contiguous_cuda<int32_t, float>
|
|
544
|
+
(src0_ddc, src1_ddc, ne, main_stream);
|
|
545
|
+
} else {
|
|
546
|
+
ggml_cpy_scalar_cuda<int32_t, float>
|
|
547
|
+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
|
548
|
+
}
|
|
385
549
|
} else {
|
|
386
550
|
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
|
387
551
|
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
|
388
552
|
}
|
|
389
|
-
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
|
|
390
|
-
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
|
|
391
|
-
ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
|
|
392
|
-
}
|
|
393
|
-
#else
|
|
394
|
-
GGML_UNUSED(disable_indirection_for_this_node);
|
|
395
|
-
#endif
|
|
396
|
-
|
|
397
553
|
}
|
|
398
554
|
|
|
399
555
|
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
400
556
|
const ggml_tensor * src0 = dst->src[0];
|
|
401
|
-
|
|
402
|
-
ggml_cuda_cpy(ctx, src0, dst, disable_indirection);
|
|
403
|
-
}
|
|
404
|
-
|
|
405
|
-
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
|
406
|
-
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
|
407
|
-
// Prioritize CUDA graph compatibility over direct memory copy optimization.
|
|
408
|
-
// Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
|
|
409
|
-
if (src0->type == GGML_TYPE_F32) {
|
|
410
|
-
return (void*) cpy_flt<cpy_1_flt<float, float>>;
|
|
411
|
-
} else {
|
|
412
|
-
return nullptr;
|
|
413
|
-
}
|
|
414
|
-
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
|
415
|
-
return (void*) cpy_flt<cpy_1_flt<float, float>>;
|
|
416
|
-
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
|
417
|
-
return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
|
|
418
|
-
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
|
419
|
-
return (void*) cpy_flt<cpy_1_flt<float, half>>;
|
|
420
|
-
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
|
421
|
-
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
|
|
422
|
-
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
|
|
423
|
-
return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
|
|
424
|
-
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
|
425
|
-
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
|
|
426
|
-
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
|
|
427
|
-
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
|
|
428
|
-
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
|
429
|
-
return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
|
|
430
|
-
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
|
|
431
|
-
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
|
|
432
|
-
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
|
433
|
-
return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
|
|
434
|
-
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
|
|
435
|
-
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
|
|
436
|
-
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
|
437
|
-
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
|
|
438
|
-
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
|
439
|
-
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
|
|
440
|
-
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
|
441
|
-
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
|
|
442
|
-
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
|
443
|
-
return (void*) cpy_flt<cpy_1_flt<half, half>>;
|
|
444
|
-
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
|
|
445
|
-
return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
|
|
446
|
-
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
|
447
|
-
return (void*) cpy_flt<cpy_1_flt<half, float>>;
|
|
448
|
-
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
|
|
449
|
-
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
|
|
450
|
-
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
|
|
451
|
-
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
|
|
452
|
-
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
|
453
|
-
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
|
|
454
|
-
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
|
|
455
|
-
return (void*) cpy_flt<cpy_1_flt<float, int32_t>>;
|
|
456
|
-
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
|
|
457
|
-
return (void*) cpy_flt<cpy_1_flt<int32_t, float>>;
|
|
458
|
-
} else {
|
|
459
|
-
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
|
460
|
-
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
|
461
|
-
}
|
|
557
|
+
ggml_cuda_cpy(ctx, src0, dst);
|
|
462
558
|
}
|