whispercpp 1.3.4 → 1.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +47 -23
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/cli/cli.cpp +121 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +10 -11
- data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
- data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
- data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
- data/ext/sources/examples/talk-llama/llama-context.h +57 -9
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
- data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
- data/ext/sources/examples/talk-llama/llama-model.h +44 -3
- data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
- data/ext/sources/examples/talk-llama/llama.cpp +729 -2
- data/ext/sources/examples/talk-llama/llama.h +152 -14
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +82 -54
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +4 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +190 -12
- data/ext/sources/ggml/src/CMakeLists.txt +82 -11
- data/ext/sources/ggml/src/ggml-alloc.c +124 -41
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
- data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
- data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
- data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
- data/ext/sources/ggml/src/ggml-impl.h +67 -6
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +425 -33
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +101 -35
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +119 -2
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +7 -0
- data/whispercpp.gemspec +1 -1
- metadata +287 -34
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -2,10 +2,6 @@
|
|
|
2
2
|
|
|
3
3
|
#define CUDA_CPY_BLOCK_SIZE 64
|
|
4
4
|
|
|
5
|
-
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1
|
|
5
|
+
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
|
|
6
6
|
|
|
7
7
|
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|
8
|
-
|
|
9
|
-
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
|
|
10
|
-
|
|
11
|
-
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);
|
|
@@ -0,0 +1,307 @@
|
|
|
1
|
+
#include <algorithm>
|
|
2
|
+
#include "cumsum.cuh"
|
|
3
|
+
#include "convert.cuh"
|
|
4
|
+
#include "ggml-cuda/common.cuh"
|
|
5
|
+
#include "ggml.h"
|
|
6
|
+
|
|
7
|
+
#ifdef GGML_CUDA_USE_CUB
|
|
8
|
+
# include <cub/cub.cuh>
|
|
9
|
+
#endif // GGML_CUDA_USE_CUB
|
|
10
|
+
|
|
11
|
+
// Inclusive prefix sum (cumsum) along dim 0, built on cub::BlockScan.
//
// Grid: one thread block per row, (blockIdx.x, blockIdx.y, blockIdx.z) = (i1, i2, i3).
// The row is processed in tiles of TILE_SIZE = BLOCK_SIZE * UNROLL_FACTOR elements. Within a
// tile, thread `tid` owns the UNROLL_FACTOR consecutive elements starting at
// start + tid * UNROLL_FACTOR. The running total of all previous tiles is kept in the
// shared variable `block_carry`.
//
// src, dst       : row base pointers are computed from element strides (s01..s03 for src,
//                  s1..s3 for dst); elements inside a row are accessed with unit stride, so
//                  dim 0 must be contiguous in both tensors.
// ne00..ne03     : src shape; ne00 is the length being scanned.
// BLOCK_SIZE     : cub::BlockScan is specialized on it, so the kernel is expected to be
//                  launched with BLOCK_SIZE threads per block.
//
// Without GGML_CUDA_USE_CUB the body is NO_DEVICE_CODE.
template<typename T, int BLOCK_SIZE>
static __global__ void cumsum_cub_kernel(
        const T * __restrict__ src,
        T * __restrict__ dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t s01, const int64_t s02, const int64_t s03,
        const int64_t s1, const int64_t s2, const int64_t s3) {
#ifdef GGML_CUDA_USE_CUB
    using BlockScanT = cub::BlockScan<T, BLOCK_SIZE>;

    __shared__ typename BlockScanT::TempStorage temp_storage;
    __shared__ T block_carry; // sum of all tiles already processed for this row

    const int tid = threadIdx.x;
    constexpr int UNROLL_FACTOR = 4;
    constexpr int TILE_SIZE = BLOCK_SIZE * UNROLL_FACTOR;

    const int64_t i1 = blockIdx.x;
    const int64_t i2 = blockIdx.y;
    const int64_t i3 = blockIdx.z;

    if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
        return;
    }

    const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
    T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;

    if (tid == 0) {
        block_carry = 0;
    }
    // make the zeroed carry visible to every thread before the first tile reads it
    __syncthreads();

    for (int64_t start = 0; start < ne00; start += TILE_SIZE) {
        T items[UNROLL_FACTOR];   // thread-local inclusive scan of this thread's elements
        T thread_sum = T(0);      // total of this thread's elements in the tile

#pragma unroll
        for (int i = 0; i < UNROLL_FACTOR; i++) {
            int64_t idx = start + tid * UNROLL_FACTOR + i;
            T val = (idx < ne00) ? src_row[idx] : T(0); // out-of-range elements contribute 0
            thread_sum += val;
            items[i] = thread_sum;
        }

        // Block-wide scan on thread sums
        T thread_prefix;
        T block_total;
        BlockScanT(temp_storage).InclusiveSum(thread_sum, thread_prefix, block_total);
        // also orders tid 0's carry update from the previous tile before the read below
        __syncthreads();

        // Add offset to each item and store
        // (thread_prefix - thread_sum) is the exclusive prefix of this thread within the tile
        T thread_offset = thread_prefix - thread_sum + block_carry;
#pragma unroll
        for (int i = 0; i < UNROLL_FACTOR; i++) {
            int64_t idx = start + tid * UNROLL_FACTOR + i;
            if (idx < ne00) {
                dst_row[idx] = items[i] + thread_offset;
            }
        }

        // every thread must have read block_carry (and be done with temp_storage)
        // before tid 0 updates the carry and the next tile reuses temp_storage
        __syncthreads();

        // Update carry for next tile
        if (tid == 0) {
            block_carry += block_total;
        }
    }
#else
    NO_DEVICE_CODE;
#endif // GGML_CUDA_USE_CUB
}
|
|
83
|
+
|
|
84
|
+
// Fallback kernel implementation
|
|
85
|
+
template<typename T>
|
|
86
|
+
static __global__ void cumsum_kernel(
|
|
87
|
+
const T * src, T * dst,
|
|
88
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
|
89
|
+
const int64_t s00, const int64_t s01, const int64_t s02, const int64_t s03,
|
|
90
|
+
const int64_t s0, const int64_t s1, const int64_t s2, const int64_t s3) {
|
|
91
|
+
|
|
92
|
+
GGML_UNUSED_VARS(s00, s0);
|
|
93
|
+
|
|
94
|
+
const int tid = threadIdx.x;
|
|
95
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
96
|
+
const int lane = tid % warp_size;
|
|
97
|
+
const int warp = tid / warp_size;
|
|
98
|
+
const int warps_per_block = blockDim.x / warp_size;
|
|
99
|
+
|
|
100
|
+
extern __shared__ float smem[];
|
|
101
|
+
float * s_vals = smem;
|
|
102
|
+
float * s_warp_sums = smem + blockDim.x;
|
|
103
|
+
float * s_carry = smem + blockDim.x + warps_per_block;
|
|
104
|
+
float * s_chunk_total = s_carry + 1;
|
|
105
|
+
|
|
106
|
+
// Initialize carry
|
|
107
|
+
if (tid == 0) {
|
|
108
|
+
*s_carry = 0.0f;
|
|
109
|
+
}
|
|
110
|
+
__syncthreads();
|
|
111
|
+
|
|
112
|
+
const int64_t i3 = blockIdx.z;
|
|
113
|
+
const int64_t i2 = blockIdx.y;
|
|
114
|
+
const int64_t i1 = blockIdx.x;
|
|
115
|
+
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
|
|
116
|
+
return;
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
|
|
120
|
+
T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;
|
|
121
|
+
|
|
122
|
+
// register blocking: process 4 elements per thread to hide latency
|
|
123
|
+
// and reduce synchronization overhead
|
|
124
|
+
constexpr int num_unroll = 4;
|
|
125
|
+
T temp[num_unroll];
|
|
126
|
+
|
|
127
|
+
for (int64_t i = 0; i < ne00; i += num_unroll * blockDim.x) {
|
|
128
|
+
int64_t idx = i + tid * num_unroll;
|
|
129
|
+
|
|
130
|
+
// thread local sequential scan
|
|
131
|
+
temp[0] = (idx < ne00 ? src_row[idx] : T(0));
|
|
132
|
+
#pragma unroll
|
|
133
|
+
for (int64_t j = 1; j < num_unroll; j++) {
|
|
134
|
+
temp[j] = temp[j - 1];
|
|
135
|
+
if (idx + j < ne00) {
|
|
136
|
+
temp[j] += src_row[idx + j];
|
|
137
|
+
} else {
|
|
138
|
+
temp[j] += 0;
|
|
139
|
+
}
|
|
140
|
+
}
|
|
141
|
+
|
|
142
|
+
// last emenent is sum of all values assigned to thread
|
|
143
|
+
float val = (idx < ne00) ? ggml_cuda_cast<float, T>(temp[num_unroll - 1]) : 0.0f;
|
|
144
|
+
|
|
145
|
+
// Warp inclusive scan
|
|
146
|
+
val = warp_prefix_inclusive_sum<T, warp_size>(val);
|
|
147
|
+
s_vals[tid] = val;
|
|
148
|
+
|
|
149
|
+
if (lane == warp_size - 1) {
|
|
150
|
+
s_warp_sums[warp] = val;
|
|
151
|
+
}
|
|
152
|
+
__syncthreads();
|
|
153
|
+
|
|
154
|
+
// Exclusive scan of warp sums (warp 0 only)
|
|
155
|
+
if (warp == 0) {
|
|
156
|
+
float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
|
|
157
|
+
float inc = warp_prefix_inclusive_sum<T, warp_size>(w);
|
|
158
|
+
if (tid < warps_per_block) {
|
|
159
|
+
s_warp_sums[tid] = inc - w; // exclusive sum
|
|
160
|
+
}
|
|
161
|
+
if (tid == warps_per_block - 1) {
|
|
162
|
+
*s_chunk_total = inc; // total sum of this chunk
|
|
163
|
+
}
|
|
164
|
+
}
|
|
165
|
+
__syncthreads();
|
|
166
|
+
|
|
167
|
+
// write back results
|
|
168
|
+
float carry = *s_carry;
|
|
169
|
+
// calculate sum offset for this thread
|
|
170
|
+
float final_val_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1];
|
|
171
|
+
|
|
172
|
+
#pragma unroll
|
|
173
|
+
for (int32_t j = 0; j < num_unroll; j++) {
|
|
174
|
+
if (idx + j < ne00) {
|
|
175
|
+
dst_row[idx + j] = temp[j] + ggml_cuda_cast<T, float>(final_val_offset);
|
|
176
|
+
}
|
|
177
|
+
}
|
|
178
|
+
|
|
179
|
+
__syncthreads();
|
|
180
|
+
|
|
181
|
+
// Update carry for next chunk
|
|
182
|
+
if (tid == 0) {
|
|
183
|
+
*s_carry += *s_chunk_total;
|
|
184
|
+
}
|
|
185
|
+
}
|
|
186
|
+
}
|
|
187
|
+
|
|
188
|
+
#ifdef GGML_CUDA_USE_CUB
// Inclusive scan of one contiguous row of `ne` elements using cub::DeviceScan::InclusiveSum.
//
// CUB uses a two-phase protocol. The first call with a null temp-storage pointer only reports
// the required temp-storage size. The second call performs the scan on `stream`, using a
// buffer taken from `pool`. Both calls return a cudaError_t. Both are checked so that a failing
// scan is reported instead of leaving `dst` silently unwritten.
template <typename T>
static void cumsum_cub(ggml_cuda_pool & pool,
                       const T *        src,
                       T *              dst,
                       int64_t          ne,
                       cudaStream_t     stream) {
    size_t tmp_size = 0;

    // Query how much temp storage CUB needs (null d_temp_storage => size query only)
    CUDA_CHECK(cub::DeviceScan::InclusiveSum(nullptr, tmp_size, src, dst, ne, stream));

    ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);

    // Perform the inclusive scan
    CUDA_CHECK(cub::DeviceScan::InclusiveSum((void *) tmp_alloc.get(), tmp_size, src, dst, ne, stream));
}
#endif // GGML_CUDA_USE_CUB
|
|
212
|
+
|
|
213
|
+
template<typename T>
|
|
214
|
+
static void cumsum_cuda(
|
|
215
|
+
[[maybe_unused]] ggml_backend_cuda_context & ctx, const T * src, T * dst,
|
|
216
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
|
217
|
+
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
|
|
218
|
+
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
|
|
219
|
+
cudaStream_t stream) {
|
|
220
|
+
|
|
221
|
+
const size_t type_size = sizeof(T);
|
|
222
|
+
bool use_cub = false;
|
|
223
|
+
#ifdef GGML_CUDA_USE_CUB
|
|
224
|
+
// Check if we can use CUB (data must be contiguous along innermost dimension)
|
|
225
|
+
const bool is_contiguous = (nb00 == type_size) && (nb0 == type_size);
|
|
226
|
+
|
|
227
|
+
if (is_contiguous) {
|
|
228
|
+
use_cub = true;
|
|
229
|
+
const int64_t nrows = ne01 * ne02 * ne03;
|
|
230
|
+
// TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released
|
|
231
|
+
// Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004
|
|
232
|
+
if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) {
|
|
233
|
+
for (int i=0; i<nrows; i++) {
|
|
234
|
+
cumsum_cub(ctx.pool(), src + i * ne00, dst + i * ne00, ne00, stream);
|
|
235
|
+
}
|
|
236
|
+
return;
|
|
237
|
+
}
|
|
238
|
+
}
|
|
239
|
+
#endif // GGML_CUDA_USE_CUB
|
|
240
|
+
dim3 grid_dims(ne01, ne02, ne03);
|
|
241
|
+
const auto &info = ggml_cuda_info().devices[ggml_cuda_get_device()];
|
|
242
|
+
const int warp_size = info.warp_size;
|
|
243
|
+
const int num_warps = (ne00 + warp_size - 1) / warp_size;
|
|
244
|
+
int block_size = num_warps * warp_size;
|
|
245
|
+
block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);
|
|
246
|
+
dim3 block_dims(block_size, 1, 1);
|
|
247
|
+
const int warps_per_block = block_size / warp_size;
|
|
248
|
+
const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);
|
|
249
|
+
|
|
250
|
+
if (use_cub && ne00 >= 1024) {
|
|
251
|
+
cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0, stream>>>(
|
|
252
|
+
src, dst,
|
|
253
|
+
ne00, ne01, ne02, ne03,
|
|
254
|
+
nb01 / type_size, nb02 / type_size, nb03 / type_size,
|
|
255
|
+
nb1 / type_size, nb2 / type_size, nb3 / type_size
|
|
256
|
+
);
|
|
257
|
+
} else {
|
|
258
|
+
cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
|
|
259
|
+
src, dst,
|
|
260
|
+
ne00, ne01, ne02, ne03,
|
|
261
|
+
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
|
|
262
|
+
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
|
|
263
|
+
);
|
|
264
|
+
}
|
|
265
|
+
}
|
|
266
|
+
|
|
267
|
+
// CUDA implementation of GGML_OP_CUMSUM: inclusive prefix sum of dst->src[0] along dim 0,
// written into dst. Only F32 is enabled; any other type aborts.
void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == dst->type);
    switch(src0->type) {
        case GGML_TYPE_F32:
            {
                cumsum_cuda(
                    ctx, (const float *)src0->data, (float *)dst->data,
                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                    stream
                );
            } break;
        // We do not support those on CPU for now anyway, so comment them out because they cause errors on some CI platforms.
        // Kept in sync with the cumsum_cuda signature (ctx is the first argument) so they can be re-enabled as-is.
        /*case GGML_TYPE_F16:
            {
                cumsum_cuda(
                    ctx, (const half *)src0->data, (half *)dst->data,
                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                    stream
                );
            } break;
        case GGML_TYPE_BF16:
            {
                cumsum_cuda(
                    ctx, (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                    stream
                );
            } break;*/
        default:
            GGML_ABORT("fatal error");
    }
}
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
#include "convert.cuh"
|
|
2
|
+
#include "diag.cuh"
|
|
3
|
+
#include "ggml.h"
|
|
4
|
+
|
|
5
|
+
template <typename T>
|
|
6
|
+
static __global__ void diag_kernel(T * __restrict__ dst,
|
|
7
|
+
const T * __restrict__ src,
|
|
8
|
+
const int64_t ne0,
|
|
9
|
+
const int64_t ne1,
|
|
10
|
+
const int64_t ne2,
|
|
11
|
+
const int64_t ne3,
|
|
12
|
+
const int64_t total_elements) {
|
|
13
|
+
const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
|
14
|
+
|
|
15
|
+
if (global_idx >= total_elements) {
|
|
16
|
+
return;
|
|
17
|
+
}
|
|
18
|
+
|
|
19
|
+
const int64_t i0 = global_idx % ne0;
|
|
20
|
+
const int64_t i1 = (global_idx / ne0) % ne1;
|
|
21
|
+
const int64_t i2 = (global_idx / (ne0 * ne1)) % ne2;
|
|
22
|
+
const int64_t i3 = global_idx / (ne0 * ne1 * ne2);
|
|
23
|
+
|
|
24
|
+
const int64_t dst_idx = ((i3 * ne2 + i2) * ne1 + i1) * ne0 + i0;
|
|
25
|
+
|
|
26
|
+
if (i0 == i1) {
|
|
27
|
+
const int64_t batch_idx = i3 * ne2 + i2;
|
|
28
|
+
const int64_t src_idx = batch_idx * ne0 + i0;
|
|
29
|
+
dst[dst_idx] = src[src_idx];
|
|
30
|
+
} else {
|
|
31
|
+
dst[dst_idx] = ggml_cuda_cast<T>(0);
|
|
32
|
+
}
|
|
33
|
+
GGML_UNUSED_VARS(ne3);
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
// CUDA implementation of GGML_OP_DIAG: dst->src[0] (shape ne00 x 1 x ne02 x ne03) is expanded
// into dst, a batch of ne0 x ne1 matrices holding src0 on the diagonal and zeros elsewhere.
// Supports F32 and F16; any other type aborts.
void ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    void *       dst_d  = dst->data;
    const void * src0_d = src0->data;

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_is_contiguous(src0));
    // The kernel reads src0 through a pointer of dst's element type.
    GGML_ASSERT(src0->type == dst->type);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

    const int64_t ne0 = dst->ne[0];
    const int64_t ne1 = dst->ne[1];
    const int64_t ne2 = dst->ne[2];
    const int64_t ne3 = dst->ne[3];

    GGML_ASSERT(ne00 == ne0);
    GGML_ASSERT(ne01 == 1);
    GGML_ASSERT(ne02 == ne2);
    GGML_ASSERT(ne03 == ne3);

    // one thread per dst element
    const int64_t n_elems    = ggml_nelements(dst);
    const int64_t num_blocks = (n_elems + CUDA_DIAG_BLOCK_SIZE - 1) / CUDA_DIAG_BLOCK_SIZE;

    switch (dst->type) {
        case GGML_TYPE_F32:
            diag_kernel<<<num_blocks, CUDA_DIAG_BLOCK_SIZE, 0, stream>>>((float *) dst_d, (const float *) src0_d, ne0,
                                                                         ne1, ne2, ne3, n_elems);
            break;
        case GGML_TYPE_F16:
            diag_kernel<<<num_blocks, CUDA_DIAG_BLOCK_SIZE, 0, stream>>>((half *) dst_d, (const half *) src0_d, ne0,
                                                                         ne1, ne2, ne3, n_elems);
            break;
        default:
            GGML_ABORT("unsupported type");
    }
}
|
|
@@ -10,6 +10,14 @@
|
|
|
10
10
|
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
|
|
11
11
|
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
|
|
12
12
|
|
|
13
|
+
// log(2) = 0.6931, by adding this to the KQ maximum used for the softmax the numerical range representable
|
|
14
|
+
// by the VKQ accumulators is effectively being shifted up by a factor of 2.
|
|
15
|
+
// This reduces issues with numerical overflow but also causes larger values to be flushed to zero.
|
|
16
|
+
// However, as the output from FlashAttention will usually be used as an input for a matrix multiplication this should be negligible.
|
|
17
|
+
// Still, the value range should be shifted as much as necessary but as little as possible.
|
|
18
|
+
// The macro on the following line shifts it by a factor of 2**3=8, as was needed to fix https://github.com/ggml-org/llama.cpp/issues/18606 .
|
|
19
|
+
#define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f)
|
|
20
|
+
|
|
13
21
|
typedef void (* fattn_kernel_t)(
|
|
14
22
|
const char * __restrict__ Q,
|
|
15
23
|
const char * __restrict__ K,
|
|
@@ -25,7 +33,7 @@ typedef void (* fattn_kernel_t)(
|
|
|
25
33
|
const float m1,
|
|
26
34
|
const uint32_t n_head_log2,
|
|
27
35
|
const float logit_softcap,
|
|
28
|
-
const int32_t ne00, const
|
|
36
|
+
const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
|
|
29
37
|
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
|
30
38
|
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
|
31
39
|
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
|
@@ -55,11 +63,11 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
|
|
|
55
63
|
ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
|
|
56
64
|
#pragma unroll
|
|
57
65
|
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
|
|
58
|
-
#ifdef
|
|
66
|
+
#ifdef V_DOT2_F32_F16_AVAILABLE
|
|
59
67
|
ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
|
|
60
68
|
#else
|
|
61
69
|
ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
|
|
62
|
-
#endif //
|
|
70
|
+
#endif // V_DOT2_F32_F16_AVAILABLE
|
|
63
71
|
}
|
|
64
72
|
}
|
|
65
73
|
|
|
@@ -621,7 +629,8 @@ static __global__ void flash_attn_mask_to_KV_max(
|
|
|
621
629
|
template<int D, int ncols1, int ncols2> // D == head size
|
|
622
630
|
__launch_bounds__(D, 1)
|
|
623
631
|
static __global__ void flash_attn_stream_k_fixup(
|
|
624
|
-
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11
|
|
632
|
+
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
|
|
633
|
+
const int nbatch_fa) {
|
|
625
634
|
constexpr int ncols = ncols1*ncols2;
|
|
626
635
|
|
|
627
636
|
const int bidx0 = blockIdx.x;
|
|
@@ -632,11 +641,11 @@ static __global__ void flash_attn_stream_k_fixup(
|
|
|
632
641
|
|
|
633
642
|
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
|
|
634
643
|
|
|
635
|
-
const int iter_k = ne11 /
|
|
636
|
-
const int iter_j = (ne01 + (ncols1
|
|
644
|
+
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
|
|
645
|
+
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
|
637
646
|
|
|
638
|
-
const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
639
|
-
const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
647
|
+
const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
648
|
+
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
640
649
|
|
|
641
650
|
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
|
642
651
|
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
|
@@ -672,7 +681,7 @@ static __global__ void flash_attn_stream_k_fixup(
|
|
|
672
681
|
int bidx = bidx0 - 1;
|
|
673
682
|
int kbc_stop = kbc0;
|
|
674
683
|
while(true) {
|
|
675
|
-
const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
684
|
+
const int kbc = int64_t(bidx)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
676
685
|
if (kbc == kbc_stop) { // Did not have any data.
|
|
677
686
|
bidx--;
|
|
678
687
|
kbc_stop = kbc;
|
|
@@ -765,7 +774,7 @@ static __global__ void flash_attn_combine_results(
|
|
|
765
774
|
template <int DV, int ncols1, int ncols2>
|
|
766
775
|
void launch_fattn(
|
|
767
776
|
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
|
|
768
|
-
const int
|
|
777
|
+
const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
|
|
769
778
|
) {
|
|
770
779
|
constexpr int ncols = ncols1 * ncols2;
|
|
771
780
|
|
|
@@ -790,10 +799,6 @@ void launch_fattn(
|
|
|
790
799
|
GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
|
|
791
800
|
|
|
792
801
|
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
|
793
|
-
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
|
794
|
-
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
|
795
|
-
|
|
796
|
-
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
|
|
797
802
|
|
|
798
803
|
ggml_cuda_pool & pool = ctx.pool();
|
|
799
804
|
cudaStream_t main_stream = ctx.stream();
|
|
@@ -878,7 +883,7 @@ void launch_fattn(
|
|
|
878
883
|
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
|
|
879
884
|
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
|
|
880
885
|
// multiple sequences of possibly different lengths.
|
|
881
|
-
if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
|
|
886
|
+
if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
|
|
882
887
|
const int s31 = mask->nb[1] / sizeof(half2);
|
|
883
888
|
const int s33 = mask->nb[3] / sizeof(half2);
|
|
884
889
|
|
|
@@ -897,6 +902,7 @@ void launch_fattn(
|
|
|
897
902
|
const dim3 block_dim(warp_size, nwarps, 1);
|
|
898
903
|
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
|
|
899
904
|
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
|
|
905
|
+
GGML_ASSERT(max_blocks_per_sm > 0);
|
|
900
906
|
int parallel_blocks = max_blocks_per_sm;
|
|
901
907
|
|
|
902
908
|
dim3 blocks_num;
|
|
@@ -914,10 +920,11 @@ void launch_fattn(
|
|
|
914
920
|
blocks_num.y = 1;
|
|
915
921
|
blocks_num.z = 1;
|
|
916
922
|
|
|
917
|
-
|
|
923
|
+
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
|
924
|
+
dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
|
|
925
|
+
}
|
|
918
926
|
} else {
|
|
919
|
-
|
|
920
|
-
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
|
927
|
+
const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
|
|
921
928
|
|
|
922
929
|
// parallel_blocks must not be larger than what the tensor size allows:
|
|
923
930
|
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
|
|
@@ -946,7 +953,7 @@ void launch_fattn(
|
|
|
946
953
|
|
|
947
954
|
blocks_num.x = ntiles_x;
|
|
948
955
|
blocks_num.y = parallel_blocks;
|
|
949
|
-
blocks_num.z = Q->ne[2]*Q->ne[3];
|
|
956
|
+
blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
|
|
950
957
|
|
|
951
958
|
if (parallel_blocks > 1) {
|
|
952
959
|
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
|
@@ -972,6 +979,9 @@ void launch_fattn(
|
|
|
972
979
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
973
980
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
974
981
|
|
|
982
|
+
// TODO other tensor dimensions after removal of WMMA kernel:
|
|
983
|
+
const uint3 ne01 = init_fastdiv_values(Q->ne[1]);
|
|
984
|
+
|
|
975
985
|
GGML_ASSERT(block_dim.x % warp_size == 0);
|
|
976
986
|
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
|
|
977
987
|
(const char *) Q->data,
|
|
@@ -982,7 +992,7 @@ void launch_fattn(
|
|
|
982
992
|
KV_max.ptr,
|
|
983
993
|
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
|
984
994
|
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
|
985
|
-
Q->ne[0],
|
|
995
|
+
Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
|
|
986
996
|
K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
|
|
987
997
|
nb21, nb22, nb23,
|
|
988
998
|
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
|
|
@@ -997,7 +1007,7 @@ void launch_fattn(
|
|
|
997
1007
|
|
|
998
1008
|
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
|
|
999
1009
|
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
|
1000
|
-
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
|
|
1010
|
+
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa);
|
|
1001
1011
|
}
|
|
1002
1012
|
} else if (parallel_blocks > 1) {
|
|
1003
1013
|
const dim3 block_dim_combine(DV, 1, 1);
|