whispercpp 1.3.4 → 1.3.5
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +47 -23
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/cli/cli.cpp +121 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +10 -11
- data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
- data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
- data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
- data/ext/sources/examples/talk-llama/llama-context.h +57 -9
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
- data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
- data/ext/sources/examples/talk-llama/llama-model.h +44 -3
- data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
- data/ext/sources/examples/talk-llama/llama.cpp +729 -2
- data/ext/sources/examples/talk-llama/llama.h +152 -14
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +82 -54
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +4 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +190 -12
- data/ext/sources/ggml/src/CMakeLists.txt +82 -11
- data/ext/sources/ggml/src/ggml-alloc.c +124 -41
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
- data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
- data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
- data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
- data/ext/sources/ggml/src/ggml-impl.h +67 -6
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +425 -33
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +101 -35
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +119 -2
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +7 -0
- data/whispercpp.gemspec +1 -1
- metadata +287 -34
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -1,3 +1,6 @@
|
|
|
1
|
+
#include "convert.cuh"
|
|
2
|
+
#include "ggml-cuda/common.cuh"
|
|
3
|
+
#include "ggml.h"
|
|
1
4
|
#include "rope.cuh"
|
|
2
5
|
|
|
3
6
|
struct rope_corr_dims {
|
|
@@ -37,11 +40,23 @@ static __device__ void rope_yarn(
|
|
|
37
40
|
}
|
|
38
41
|
}
|
|
39
42
|
|
|
40
|
-
template<bool forward, bool has_ff, typename T>
|
|
41
|
-
static __global__ void rope_norm(
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
43
|
+
template <bool forward, bool has_ff, typename T, typename D>
|
|
44
|
+
static __global__ void rope_norm(const T * x,
|
|
45
|
+
D * dst,
|
|
46
|
+
const int ne0,
|
|
47
|
+
const int ne1,
|
|
48
|
+
const int s1,
|
|
49
|
+
const int s2,
|
|
50
|
+
const int n_dims,
|
|
51
|
+
const int32_t * pos,
|
|
52
|
+
const float freq_scale,
|
|
53
|
+
const float ext_factor,
|
|
54
|
+
const float attn_factor,
|
|
55
|
+
const rope_corr_dims corr_dims,
|
|
56
|
+
const float theta_scale,
|
|
57
|
+
const float * freq_factors,
|
|
58
|
+
const int64_t * row_indices,
|
|
59
|
+
const int set_rows_stride) {
|
|
45
60
|
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
|
46
61
|
|
|
47
62
|
if (i0 >= ne0) {
|
|
@@ -53,13 +68,27 @@ static __global__ void rope_norm(
|
|
|
53
68
|
const int row_x = row_dst % ne1;
|
|
54
69
|
const int channel_x = row_dst / ne1;
|
|
55
70
|
|
|
56
|
-
|
|
71
|
+
int idst = row_dst * ne0 + i0;
|
|
57
72
|
const int ix = channel_x*s2 + row_x*s1 + i0;
|
|
58
73
|
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
74
|
+
// Fusion optimization: ROPE + VIEW + SET_ROWS.
|
|
75
|
+
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
|
|
76
|
+
if (set_rows_stride != 0) {
|
|
77
|
+
idst = row_x * ne0 + i0;
|
|
78
|
+
idst += row_indices[channel_x] * set_rows_stride;
|
|
79
|
+
}
|
|
62
80
|
|
|
81
|
+
const auto & store_coaelsced = [&](float x0, float x1) {
|
|
82
|
+
if constexpr (std::is_same_v<float, D>) {
|
|
83
|
+
float2 v = make_float2(x0, x1);
|
|
84
|
+
ggml_cuda_memcpy_1<8>(dst + idst, &v);
|
|
85
|
+
} else if constexpr (std::is_same_v<half, D>) {
|
|
86
|
+
half2 v = make_half2(x0, x1);
|
|
87
|
+
ggml_cuda_memcpy_1<4>(dst + idst, &v);
|
|
88
|
+
}
|
|
89
|
+
};
|
|
90
|
+
if (i0 >= n_dims) {
|
|
91
|
+
store_coaelsced(x[ix + 0], x[ix + 1]);
|
|
63
92
|
return;
|
|
64
93
|
}
|
|
65
94
|
|
|
@@ -75,15 +104,26 @@ static __global__ void rope_norm(
|
|
|
75
104
|
const float x0 = x[ix + 0];
|
|
76
105
|
const float x1 = x[ix + 1];
|
|
77
106
|
|
|
78
|
-
|
|
79
|
-
dst[idst + 1] = x0*sin_theta + x1*cos_theta;
|
|
107
|
+
store_coaelsced(x0 * cos_theta - x1 * sin_theta, x0 * sin_theta + x1 * cos_theta);
|
|
80
108
|
}
|
|
81
109
|
|
|
82
|
-
template<bool forward, bool has_ff, typename T>
|
|
83
|
-
static __global__ void rope_neox(
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
110
|
+
template <bool forward, bool has_ff, typename T, typename D>
|
|
111
|
+
static __global__ void rope_neox(const T * x,
|
|
112
|
+
D * dst,
|
|
113
|
+
const int ne0,
|
|
114
|
+
const int ne1,
|
|
115
|
+
const int s1,
|
|
116
|
+
const int s2,
|
|
117
|
+
const int n_dims,
|
|
118
|
+
const int32_t * pos,
|
|
119
|
+
const float freq_scale,
|
|
120
|
+
const float ext_factor,
|
|
121
|
+
const float attn_factor,
|
|
122
|
+
const rope_corr_dims corr_dims,
|
|
123
|
+
const float theta_scale,
|
|
124
|
+
const float * freq_factors,
|
|
125
|
+
const int64_t * row_indices,
|
|
126
|
+
const int set_rows_stride) {
|
|
87
127
|
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
|
88
128
|
|
|
89
129
|
if (i0 >= ne0) {
|
|
@@ -95,12 +135,19 @@ static __global__ void rope_neox(
|
|
|
95
135
|
const int row_x = row_dst % ne1;
|
|
96
136
|
const int channel_x = row_dst / ne1;
|
|
97
137
|
|
|
98
|
-
|
|
138
|
+
int idst = row_dst * ne0 + i0 / 2;
|
|
99
139
|
const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
|
100
140
|
|
|
141
|
+
// Fusion optimization: ROPE + VIEW + SET_ROWS.
|
|
142
|
+
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
|
|
143
|
+
if (set_rows_stride != 0) {
|
|
144
|
+
idst = row_x * ne0 + i0 / 2;
|
|
145
|
+
idst += row_indices[channel_x] * set_rows_stride;
|
|
146
|
+
}
|
|
147
|
+
|
|
101
148
|
if (i0 >= n_dims) {
|
|
102
|
-
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
|
|
103
|
-
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
|
|
149
|
+
dst[idst + i0 / 2 + 0] = ggml_cuda_cast<D>(x[ix + i0 / 2 + 0]);
|
|
150
|
+
dst[idst + i0 / 2 + 1] = ggml_cuda_cast<D>(x[ix + i0 / 2 + 1]);
|
|
104
151
|
|
|
105
152
|
return;
|
|
106
153
|
}
|
|
@@ -117,15 +164,15 @@ static __global__ void rope_neox(
|
|
|
117
164
|
const float x0 = x[ix + 0];
|
|
118
165
|
const float x1 = x[ix + n_dims/2];
|
|
119
166
|
|
|
120
|
-
dst[idst + 0]
|
|
121
|
-
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
|
167
|
+
dst[idst + 0] = ggml_cuda_cast<D>(x0 * cos_theta - x1 * sin_theta);
|
|
168
|
+
dst[idst + n_dims / 2] = ggml_cuda_cast<D>(x0 * sin_theta + x1 * cos_theta);
|
|
122
169
|
}
|
|
123
170
|
|
|
124
171
|
template<bool forward, bool has_ff, typename T>
|
|
125
172
|
static __global__ void rope_multi(
|
|
126
173
|
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
|
|
127
174
|
const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
|
|
128
|
-
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) {
|
|
175
|
+
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
|
|
129
176
|
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
|
130
177
|
|
|
131
178
|
if (i0 >= ne0) {
|
|
@@ -152,17 +199,29 @@ static __global__ void rope_multi(
|
|
|
152
199
|
const int sector = (i0 / 2) % sect_dims;
|
|
153
200
|
|
|
154
201
|
float theta_base = 0.0;
|
|
155
|
-
if (
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
202
|
+
if (is_imrope) {
|
|
203
|
+
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
|
|
204
|
+
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
|
|
205
|
+
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
|
|
206
|
+
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
|
|
207
|
+
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
|
|
208
|
+
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
|
209
|
+
} else {
|
|
210
|
+
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
|
|
211
|
+
}
|
|
212
|
+
} else {
|
|
213
|
+
if (sector < sections.v[0]) {
|
|
214
|
+
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
|
215
|
+
}
|
|
216
|
+
else if (sector >= sections.v[0] && sector < sec_w) {
|
|
217
|
+
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
|
|
218
|
+
}
|
|
219
|
+
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
|
220
|
+
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
|
|
221
|
+
}
|
|
222
|
+
else if (sector >= sec_w + sections.v[2]) {
|
|
223
|
+
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
|
|
224
|
+
}
|
|
166
225
|
}
|
|
167
226
|
|
|
168
227
|
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
|
@@ -226,11 +285,25 @@ static __global__ void rope_vision(
|
|
|
226
285
|
dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
|
|
227
286
|
}
|
|
228
287
|
|
|
229
|
-
template<bool forward, typename T>
|
|
230
|
-
static void rope_norm_cuda(
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
288
|
+
template <bool forward, typename T, typename D>
|
|
289
|
+
static void rope_norm_cuda(const T * x,
|
|
290
|
+
D * dst,
|
|
291
|
+
const int ne0,
|
|
292
|
+
const int ne1,
|
|
293
|
+
const int s1,
|
|
294
|
+
const int s2,
|
|
295
|
+
const int n_dims,
|
|
296
|
+
const int nr,
|
|
297
|
+
const int32_t * pos,
|
|
298
|
+
const float freq_scale,
|
|
299
|
+
const float freq_base,
|
|
300
|
+
const float ext_factor,
|
|
301
|
+
const float attn_factor,
|
|
302
|
+
const rope_corr_dims corr_dims,
|
|
303
|
+
const float * freq_factors,
|
|
304
|
+
const int64_t * row_indices,
|
|
305
|
+
const int set_rows_stride,
|
|
306
|
+
cudaStream_t stream) {
|
|
234
307
|
GGML_ASSERT(ne0 % 2 == 0);
|
|
235
308
|
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
|
236
309
|
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
|
@@ -240,20 +313,34 @@ static void rope_norm_cuda(
|
|
|
240
313
|
|
|
241
314
|
if (freq_factors == nullptr) {
|
|
242
315
|
rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
|
|
243
|
-
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
|
244
|
-
|
|
316
|
+
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
|
|
317
|
+
freq_factors, row_indices, set_rows_stride);
|
|
245
318
|
} else {
|
|
246
319
|
rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
|
|
247
|
-
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
|
248
|
-
|
|
320
|
+
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
|
|
321
|
+
freq_factors, row_indices, set_rows_stride);
|
|
249
322
|
}
|
|
250
323
|
}
|
|
251
324
|
|
|
252
|
-
template<bool forward, typename T>
|
|
253
|
-
static void rope_neox_cuda(
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
325
|
+
template <bool forward, typename T, typename D>
|
|
326
|
+
static void rope_neox_cuda(const T * x,
|
|
327
|
+
D * dst,
|
|
328
|
+
const int ne0,
|
|
329
|
+
const int ne1,
|
|
330
|
+
const int s1,
|
|
331
|
+
const int s2,
|
|
332
|
+
const int n_dims,
|
|
333
|
+
const int nr,
|
|
334
|
+
const int32_t * pos,
|
|
335
|
+
const float freq_scale,
|
|
336
|
+
const float freq_base,
|
|
337
|
+
const float ext_factor,
|
|
338
|
+
const float attn_factor,
|
|
339
|
+
const rope_corr_dims corr_dims,
|
|
340
|
+
const float * freq_factors,
|
|
341
|
+
const int64_t * row_indices,
|
|
342
|
+
const int set_rows_stride,
|
|
343
|
+
cudaStream_t stream) {
|
|
257
344
|
GGML_ASSERT(ne0 % 2 == 0);
|
|
258
345
|
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
|
259
346
|
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
|
@@ -262,13 +349,13 @@ static void rope_neox_cuda(
|
|
|
262
349
|
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
|
263
350
|
|
|
264
351
|
if (freq_factors == nullptr) {
|
|
265
|
-
rope_neox<forward, false
|
|
266
|
-
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
|
267
|
-
|
|
352
|
+
rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
|
|
353
|
+
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
|
|
354
|
+
freq_factors, row_indices, set_rows_stride);
|
|
268
355
|
} else {
|
|
269
|
-
rope_neox<forward, true
|
|
270
|
-
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
|
271
|
-
|
|
356
|
+
rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
|
|
357
|
+
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
|
|
358
|
+
freq_factors, row_indices, set_rows_stride);
|
|
272
359
|
}
|
|
273
360
|
}
|
|
274
361
|
|
|
@@ -276,7 +363,7 @@ template<bool forward, typename T>
|
|
|
276
363
|
static void rope_multi_cuda(
|
|
277
364
|
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
|
|
278
365
|
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
|
|
279
|
-
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
|
|
366
|
+
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
|
|
280
367
|
GGML_ASSERT(ne0 % 2 == 0);
|
|
281
368
|
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
|
282
369
|
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
|
@@ -287,11 +374,11 @@ static void rope_multi_cuda(
|
|
|
287
374
|
if (freq_factors == nullptr) {
|
|
288
375
|
rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
|
|
289
376
|
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
|
290
|
-
attn_factor, corr_dims, theta_scale, freq_factors, sections);
|
|
377
|
+
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
|
|
291
378
|
} else {
|
|
292
379
|
rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
|
|
293
380
|
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
|
294
|
-
attn_factor, corr_dims, theta_scale, freq_factors, sections);
|
|
381
|
+
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
|
|
295
382
|
}
|
|
296
383
|
}
|
|
297
384
|
|
|
@@ -321,7 +408,9 @@ static void rope_vision_cuda(
|
|
|
321
408
|
}
|
|
322
409
|
|
|
323
410
|
template <bool forward>
|
|
324
|
-
void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
|
|
411
|
+
void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
|
|
412
|
+
ggml_tensor * dst,
|
|
413
|
+
const ggml_tensor * set_rows = nullptr) {
|
|
325
414
|
const ggml_tensor * src0 = dst->src[0];
|
|
326
415
|
const ggml_tensor * src1 = dst->src[1];
|
|
327
416
|
const ggml_tensor * src2 = dst->src[2];
|
|
@@ -329,12 +418,25 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
|
|
329
418
|
const float * src0_d = (const float *)src0->data;
|
|
330
419
|
const float * src1_d = (const float *)src1->data;
|
|
331
420
|
|
|
332
|
-
|
|
421
|
+
void * dst_d = dst->data;
|
|
422
|
+
const int64_t * row_indices = nullptr;
|
|
423
|
+
ggml_type dst_type = dst->type;
|
|
424
|
+
int set_rows_stride = 0;
|
|
425
|
+
|
|
426
|
+
if (set_rows != nullptr) {
|
|
427
|
+
GGML_ASSERT(forward);
|
|
428
|
+
dst_d = set_rows->data;
|
|
429
|
+
row_indices = (const int64_t *) set_rows->src[1]->data;
|
|
430
|
+
dst_type = set_rows->type;
|
|
431
|
+
set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
|
|
432
|
+
}
|
|
333
433
|
cudaStream_t stream = ctx.stream();
|
|
334
434
|
|
|
335
435
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
336
436
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
337
|
-
|
|
437
|
+
// When not fused, src0 and dst types must match
|
|
438
|
+
// When fused (ROPE+VIEW+SET_ROWS), src0 may be F32 and dst may be F16
|
|
439
|
+
GGML_ASSERT(src0->type == dst->type || (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16));
|
|
338
440
|
|
|
339
441
|
const int64_t ne00 = src0->ne[0]; // head dims
|
|
340
442
|
const int64_t ne01 = src0->ne[1]; // num heads
|
|
@@ -369,6 +471,7 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
|
|
369
471
|
|
|
370
472
|
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
|
371
473
|
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
|
474
|
+
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
|
|
372
475
|
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
|
373
476
|
|
|
374
477
|
if (is_mrope) {
|
|
@@ -391,14 +494,18 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
|
|
391
494
|
|
|
392
495
|
// compute
|
|
393
496
|
if (is_neox) {
|
|
394
|
-
if (src0->type == GGML_TYPE_F32) {
|
|
395
|
-
rope_neox_cuda<forward>(
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
} else if (src0->type == GGML_TYPE_F16) {
|
|
399
|
-
rope_neox_cuda<forward>(
|
|
400
|
-
|
|
401
|
-
|
|
497
|
+
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
|
|
498
|
+
rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
|
|
499
|
+
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
|
500
|
+
freq_factors, row_indices, set_rows_stride, stream);
|
|
501
|
+
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
|
|
502
|
+
rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
|
|
503
|
+
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
|
504
|
+
freq_factors, row_indices, set_rows_stride, stream);
|
|
505
|
+
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
|
|
506
|
+
rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
|
|
507
|
+
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
|
508
|
+
freq_factors, row_indices, set_rows_stride, stream);
|
|
402
509
|
} else {
|
|
403
510
|
GGML_ABORT("fatal error");
|
|
404
511
|
}
|
|
@@ -406,11 +513,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
|
|
406
513
|
if (src0->type == GGML_TYPE_F32) {
|
|
407
514
|
rope_multi_cuda<forward>(
|
|
408
515
|
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
|
409
|
-
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
|
516
|
+
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
|
|
410
517
|
} else if (src0->type == GGML_TYPE_F16) {
|
|
411
518
|
rope_multi_cuda<forward>(
|
|
412
519
|
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
|
413
|
-
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
|
520
|
+
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
|
|
414
521
|
} else {
|
|
415
522
|
GGML_ABORT("fatal error");
|
|
416
523
|
}
|
|
@@ -427,14 +534,18 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
|
|
427
534
|
GGML_ABORT("fatal error");
|
|
428
535
|
}
|
|
429
536
|
} else {
|
|
430
|
-
if (src0->type == GGML_TYPE_F32) {
|
|
431
|
-
rope_norm_cuda<forward>(
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
} else if (src0->type == GGML_TYPE_F16) {
|
|
435
|
-
rope_norm_cuda<forward>(
|
|
436
|
-
|
|
437
|
-
|
|
537
|
+
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
|
|
538
|
+
rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
|
|
539
|
+
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
|
540
|
+
freq_factors, row_indices, set_rows_stride, stream);
|
|
541
|
+
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
|
|
542
|
+
rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
|
|
543
|
+
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
|
544
|
+
freq_factors, row_indices, set_rows_stride, stream);
|
|
545
|
+
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
|
|
546
|
+
rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
|
|
547
|
+
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
|
548
|
+
freq_factors, row_indices, set_rows_stride, stream);
|
|
438
549
|
} else {
|
|
439
550
|
GGML_ABORT("fatal error");
|
|
440
551
|
}
|
|
@@ -448,3 +559,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
448
559
|
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
449
560
|
ggml_cuda_op_rope_impl<false>(ctx, dst);
|
|
450
561
|
}
|
|
562
|
+
|
|
563
|
+
void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * rope, ggml_tensor * set_rows) {
|
|
564
|
+
ggml_cuda_op_rope_impl<true>(ctx, rope, set_rows);
|
|
565
|
+
}
|
|
@@ -4,30 +4,53 @@
|
|
|
4
4
|
typedef void (*set_rows_kernel_t)(const char * src, char * dst);
|
|
5
5
|
|
|
6
6
|
// Generic quantized set_rows kernel template
|
|
7
|
-
template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
|
|
8
|
-
static __global__ void k_set_rows_quant(
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
7
|
+
template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type *)>
|
|
8
|
+
static __global__ void k_set_rows_quant(const float * __restrict__ src0,
|
|
9
|
+
const idx_t * __restrict__ src1,
|
|
10
|
+
block_type * __restrict__ dst,
|
|
11
|
+
const int64_t ne_total,
|
|
12
|
+
const int64_t ne10,
|
|
13
|
+
const int64_t ne11,
|
|
14
|
+
const int64_t ne12,
|
|
15
|
+
const int64_t ne13,
|
|
16
|
+
const int64_t s01,
|
|
17
|
+
const int64_t s02,
|
|
18
|
+
const int64_t s03,
|
|
19
|
+
const int64_t s10,
|
|
20
|
+
const int64_t s11,
|
|
21
|
+
const int64_t s12,
|
|
22
|
+
const int64_t s1,
|
|
23
|
+
const int64_t s2,
|
|
24
|
+
const int64_t s3,
|
|
25
|
+
const uint3 ne00,
|
|
26
|
+
const uint3 ne01,
|
|
27
|
+
const uint3 ne02,
|
|
28
|
+
const uint3 ne11_fd,
|
|
29
|
+
const uint3 ne12_fd) {
|
|
16
30
|
const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
|
|
17
|
-
const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
|
|
18
31
|
|
|
19
32
|
if (i >= ne_total) {
|
|
20
33
|
return;
|
|
21
34
|
}
|
|
22
35
|
|
|
23
36
|
const int64_t i_base = i * qk;
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
37
|
+
uint32_t tmp = (uint32_t) i_base;
|
|
38
|
+
uint2 div_mod;
|
|
39
|
+
|
|
40
|
+
div_mod = fast_div_modulo(tmp, ne00);
|
|
41
|
+
const int64_t i00 = div_mod.y;
|
|
42
|
+
tmp = div_mod.x;
|
|
28
43
|
|
|
29
|
-
|
|
30
|
-
const int64_t
|
|
44
|
+
div_mod = fast_div_modulo(tmp, ne01);
|
|
45
|
+
const int64_t i01 = div_mod.y;
|
|
46
|
+
tmp = div_mod.x;
|
|
47
|
+
|
|
48
|
+
div_mod = fast_div_modulo(tmp, ne02);
|
|
49
|
+
const int64_t i02 = div_mod.y;
|
|
50
|
+
const int64_t i03 = div_mod.x;
|
|
51
|
+
|
|
52
|
+
const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
|
|
53
|
+
const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
|
|
31
54
|
const int64_t i10 = i01;
|
|
32
55
|
|
|
33
56
|
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
|
|
@@ -41,6 +64,8 @@ static __global__ void k_set_rows_quant(
|
|
|
41
64
|
quantize_func(src_block, dst_block);
|
|
42
65
|
|
|
43
66
|
GGML_UNUSED(ne10);
|
|
67
|
+
GGML_UNUSED(ne11);
|
|
68
|
+
GGML_UNUSED(ne12);
|
|
44
69
|
GGML_UNUSED(ne13);
|
|
45
70
|
}
|
|
46
71
|
|
|
@@ -71,40 +96,65 @@ static void set_rows_cuda_quant(
|
|
|
71
96
|
const int64_t s2 = nb2;
|
|
72
97
|
const int64_t s3 = nb3;
|
|
73
98
|
|
|
74
|
-
if (ne_total > 0) {
|
|
99
|
+
if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
|
|
100
|
+
const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
|
|
101
|
+
const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
|
|
102
|
+
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
|
|
103
|
+
const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
|
|
104
|
+
const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
|
|
105
|
+
|
|
75
106
|
k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
|
|
76
|
-
src0_d, src1_d, dst_d,
|
|
77
|
-
|
|
78
|
-
ne10, ne11, ne12, ne13,
|
|
79
|
-
s01, s02, s03,
|
|
80
|
-
s10, s11, s12,
|
|
81
|
-
s1, s2, s3);
|
|
107
|
+
src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd,
|
|
108
|
+
ne01_fd, ne02_fd, ne11_fd, ne12_fd);
|
|
82
109
|
}
|
|
83
110
|
}
|
|
84
111
|
|
|
85
|
-
template<typename src_t, typename idx_t, typename dst_t>
|
|
86
|
-
static __global__ void k_set_rows(
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
112
|
+
template <typename src_t, typename idx_t, typename dst_t>
|
|
113
|
+
static __global__ void k_set_rows(const src_t * __restrict__ src0,
|
|
114
|
+
const idx_t * __restrict__ src1,
|
|
115
|
+
dst_t * __restrict__ dst,
|
|
116
|
+
const int64_t ne_total,
|
|
117
|
+
const int64_t ne10,
|
|
118
|
+
const int64_t ne11,
|
|
119
|
+
const int64_t ne12,
|
|
120
|
+
const int64_t ne13,
|
|
121
|
+
const int64_t s01,
|
|
122
|
+
const int64_t s02,
|
|
123
|
+
const int64_t s03,
|
|
124
|
+
const int64_t s10,
|
|
125
|
+
const int64_t s11,
|
|
126
|
+
const int64_t s12,
|
|
127
|
+
const int64_t s1,
|
|
128
|
+
const int64_t s2,
|
|
129
|
+
const int64_t s3,
|
|
130
|
+
const uint3 ne00,
|
|
131
|
+
const uint3 ne01,
|
|
132
|
+
const uint3 ne02,
|
|
133
|
+
const uint3 ne11_fd,
|
|
134
|
+
const uint3 ne12_fd) {
|
|
94
135
|
const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
|
|
95
|
-
const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
|
|
96
136
|
|
|
97
137
|
if (i >= ne_total) {
|
|
98
138
|
return;
|
|
99
139
|
}
|
|
100
140
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
141
|
+
uint32_t tmp = (uint32_t) i;
|
|
142
|
+
uint2 div_mod;
|
|
143
|
+
|
|
144
|
+
div_mod = fast_div_modulo(tmp, ne00);
|
|
145
|
+
const int64_t i00 = div_mod.y;
|
|
146
|
+
tmp = div_mod.x;
|
|
105
147
|
|
|
106
|
-
|
|
107
|
-
const int64_t
|
|
148
|
+
div_mod = fast_div_modulo(tmp, ne01);
|
|
149
|
+
const int64_t i01 = div_mod.y;
|
|
150
|
+
tmp = div_mod.x;
|
|
151
|
+
|
|
152
|
+
div_mod = fast_div_modulo(tmp, ne02);
|
|
153
|
+
const int64_t i02 = div_mod.y;
|
|
154
|
+
const int64_t i03 = div_mod.x;
|
|
155
|
+
|
|
156
|
+
const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
|
|
157
|
+
const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
|
|
108
158
|
const int64_t i10 = i01;
|
|
109
159
|
|
|
110
160
|
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
|
|
@@ -115,6 +165,8 @@ static __global__ void k_set_rows(
|
|
|
115
165
|
dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
|
|
116
166
|
|
|
117
167
|
GGML_UNUSED(ne10);
|
|
168
|
+
GGML_UNUSED(ne11);
|
|
169
|
+
GGML_UNUSED(ne12);
|
|
118
170
|
GGML_UNUSED(ne13);
|
|
119
171
|
}
|
|
120
172
|
|
|
@@ -144,14 +196,16 @@ static void set_rows_cuda(
|
|
|
144
196
|
const int64_t s2 = nb2/sizeof(dst_t);
|
|
145
197
|
const int64_t s3 = nb3/sizeof(dst_t);
|
|
146
198
|
|
|
147
|
-
if (ne_total > 0) {
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
199
|
+
if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
|
|
200
|
+
const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
|
|
201
|
+
const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
|
|
202
|
+
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
|
|
203
|
+
const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
|
|
204
|
+
const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
|
|
205
|
+
|
|
206
|
+
k_set_rows<<<grid_size, block_size, 0, stream>>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,
|
|
207
|
+
s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,
|
|
208
|
+
ne11_fd, ne12_fd);
|
|
155
209
|
}
|
|
156
210
|
}
|
|
157
211
|
|