whispercpp 1.3.4 → 1.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +47 -23
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/cli/cli.cpp +121 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +10 -11
- data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
- data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
- data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
- data/ext/sources/examples/talk-llama/llama-context.h +57 -9
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
- data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
- data/ext/sources/examples/talk-llama/llama-model.h +44 -3
- data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
- data/ext/sources/examples/talk-llama/llama.cpp +729 -2
- data/ext/sources/examples/talk-llama/llama.h +152 -14
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +82 -54
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +4 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +190 -12
- data/ext/sources/ggml/src/CMakeLists.txt +82 -11
- data/ext/sources/ggml/src/ggml-alloc.c +124 -41
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
- data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
- data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
- data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
- data/ext/sources/ggml/src/ggml-impl.h +67 -6
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +425 -33
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +101 -35
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +119 -2
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +7 -0
- data/whispercpp.gemspec +1 -1
- metadata +287 -34
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -0,0 +1,454 @@
|
|
|
1
|
+
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
|
2
|
+
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
|
|
3
|
+
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
|
4
|
+
|
|
5
|
+
#include "types.glsl"
|
|
6
|
+
|
|
7
|
+
// Each iqs value maps to a 32-bit integer
|
|
8
|
+
|
|
9
|
+
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
|
|
10
|
+
// 2-byte loads for Q4_0 blocks (18 bytes)
|
|
11
|
+
// 4-byte loads for Q4_1 blocks (20 bytes)
|
|
12
|
+
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
|
13
|
+
#ifdef DATA_A_Q4_0
|
|
14
|
+
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
|
|
15
|
+
data_a_packed16[ib].qs[iqs * 2 + 1]));
|
|
16
|
+
|
|
17
|
+
if (iqs == 0) {
|
|
18
|
+
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
|
|
19
|
+
}
|
|
20
|
+
#else // DATA_A_Q4_1
|
|
21
|
+
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
|
|
22
|
+
|
|
23
|
+
if (iqs == 0) {
|
|
24
|
+
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
|
|
25
|
+
}
|
|
26
|
+
#endif
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
|
30
|
+
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
|
|
31
|
+
|
|
32
|
+
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
|
|
33
|
+
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
|
34
|
+
}
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
|
38
|
+
int32_t q_sum = 0;
|
|
39
|
+
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
|
|
40
|
+
const uint32_t vui = cache_a[ib_a].qs[iqs];
|
|
41
|
+
const i32vec2 qs_a = i32vec2( vui & 0x0F0F0F0F,
|
|
42
|
+
(vui >> 4) & 0x0F0F0F0F);
|
|
43
|
+
|
|
44
|
+
const int32_t qs_b0 = cache_b.qs[iqs];
|
|
45
|
+
const int32_t qs_b1 = cache_b.qs[iqs + 4];
|
|
46
|
+
|
|
47
|
+
q_sum += dotPacked4x8EXT(qs_a.x, qs_b0);
|
|
48
|
+
q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
#ifdef DATA_A_Q4_0
|
|
52
|
+
return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 8.0 * float(cache_b.ds.y)));
|
|
53
|
+
#else // DATA_A_Q4_1
|
|
54
|
+
return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
|
|
55
|
+
#endif
|
|
56
|
+
}
|
|
57
|
+
#endif
|
|
58
|
+
|
|
59
|
+
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
|
60
|
+
// 2-byte loads for Q5_0 blocks (22 bytes)
|
|
61
|
+
// 4-byte loads for Q5_1 blocks (24 bytes)
|
|
62
|
+
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
|
63
|
+
#ifdef DATA_A_Q5_0
|
|
64
|
+
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
|
|
65
|
+
data_a_packed16[ib].qs[iqs * 2 + 1]));
|
|
66
|
+
|
|
67
|
+
if (iqs == 0) {
|
|
68
|
+
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
|
|
69
|
+
buf_a[buf_ib].qh = pack32(u16vec2(data_a_packed16[ib].qh[0], data_a_packed16[ib].qh[1]));
|
|
70
|
+
}
|
|
71
|
+
#else // DATA_A_Q5_1
|
|
72
|
+
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
|
|
73
|
+
|
|
74
|
+
if (iqs == 0) {
|
|
75
|
+
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
|
|
76
|
+
buf_a[buf_ib].qh = data_a_packed32[ib].qh;
|
|
77
|
+
}
|
|
78
|
+
#endif
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
|
82
|
+
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
|
|
83
|
+
cache_a[reg_ib].qh = buf_a[buf_ib].qh;
|
|
84
|
+
|
|
85
|
+
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
|
|
86
|
+
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
|
87
|
+
}
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
|
91
|
+
int32_t q_sum = 0;
|
|
92
|
+
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
|
|
93
|
+
const uint32_t vui = cache_a[ib_a].qs[iqs];
|
|
94
|
+
const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs));
|
|
95
|
+
const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F)
|
|
96
|
+
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
|
|
97
|
+
const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F)
|
|
98
|
+
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
|
|
99
|
+
|
|
100
|
+
const int32_t qs_b0 = cache_b.qs[iqs];
|
|
101
|
+
const int32_t qs_b1 = cache_b.qs[iqs + 4];
|
|
102
|
+
|
|
103
|
+
q_sum += dotPacked4x8EXT(qs_a0, qs_b0);
|
|
104
|
+
q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
|
|
105
|
+
}
|
|
106
|
+
|
|
107
|
+
#ifdef DATA_A_Q5_0
|
|
108
|
+
return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 16.0 * float(cache_b.ds.y)));
|
|
109
|
+
#else // DATA_A_Q5_1
|
|
110
|
+
return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
|
|
111
|
+
#endif
|
|
112
|
+
}
|
|
113
|
+
#endif
|
|
114
|
+
|
|
115
|
+
#if defined(DATA_A_Q8_0)
|
|
116
|
+
// 2-byte loads for Q8_0 blocks (34 bytes)
|
|
117
|
+
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
|
118
|
+
buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2],
|
|
119
|
+
data_a_packed16[ib].qs[iqs * 2 + 1]));
|
|
120
|
+
|
|
121
|
+
if (iqs == 0) {
|
|
122
|
+
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
|
|
123
|
+
}
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
|
127
|
+
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
|
|
128
|
+
|
|
129
|
+
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
|
|
130
|
+
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
|
131
|
+
}
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
|
135
|
+
int32_t q_sum = 0;
|
|
136
|
+
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
|
|
137
|
+
const int32_t qs_a = cache_a[ib_a].qs[iqs];
|
|
138
|
+
const int32_t qs_b = cache_b.qs[iqs];
|
|
139
|
+
|
|
140
|
+
q_sum += dotPacked4x8EXT(qs_a, qs_b);
|
|
141
|
+
}
|
|
142
|
+
|
|
143
|
+
return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm) * float(cache_b.ds.x));
|
|
144
|
+
}
|
|
145
|
+
#endif
|
|
146
|
+
|
|
147
|
+
#if defined(DATA_A_MXFP4)
|
|
148
|
+
// 1-byte loads for mxfp4 blocks (17 bytes)
|
|
149
|
+
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
|
150
|
+
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
|
|
151
|
+
data_a[ib].qs[iqs * 4 + 1],
|
|
152
|
+
data_a[ib].qs[iqs * 4 + 2],
|
|
153
|
+
data_a[ib].qs[iqs * 4 + 3]));
|
|
154
|
+
|
|
155
|
+
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
|
|
156
|
+
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
|
|
157
|
+
|
|
158
|
+
buf_a[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w]));
|
|
159
|
+
buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w]));
|
|
160
|
+
|
|
161
|
+
if (iqs == 0) {
|
|
162
|
+
buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5);
|
|
163
|
+
}
|
|
164
|
+
}
|
|
165
|
+
|
|
166
|
+
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
|
167
|
+
cache_a[reg_ib].d = buf_a[buf_ib].d;
|
|
168
|
+
|
|
169
|
+
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
|
|
170
|
+
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
|
171
|
+
}
|
|
172
|
+
}
|
|
173
|
+
|
|
174
|
+
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
|
175
|
+
int32_t q_sum = 0;
|
|
176
|
+
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
|
|
177
|
+
const int32_t qs_a = cache_a[ib_a].qs[iqs];
|
|
178
|
+
|
|
179
|
+
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
|
|
180
|
+
}
|
|
181
|
+
|
|
182
|
+
return ACC_TYPE(float(cache_a[ib_a].d) * float(cache_b.ds.x) * float(q_sum));
|
|
183
|
+
}
|
|
184
|
+
#endif
|
|
185
|
+
|
|
186
|
+
// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
|
|
187
|
+
// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
|
|
188
|
+
#if defined(DATA_A_Q2_K)
|
|
189
|
+
// 4-byte loads for Q2_K blocks (84 bytes)
|
|
190
|
+
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
|
191
|
+
const uint ib_k = ib / 8;
|
|
192
|
+
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
|
|
193
|
+
|
|
194
|
+
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
|
|
195
|
+
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
|
|
196
|
+
|
|
197
|
+
// Repack 4x4 quants into one int
|
|
198
|
+
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303;
|
|
199
|
+
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
|
|
200
|
+
const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
|
|
201
|
+
const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
|
|
202
|
+
|
|
203
|
+
buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
|
|
204
|
+
|
|
205
|
+
if (iqs == 0) {
|
|
206
|
+
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
|
|
207
|
+
buf_a[buf_ib].scales = unpack8(uint32_t(data_a_packed16[ib_k].scales[iqs_k / 8])).xy; // vec4 used due to #12147
|
|
208
|
+
}
|
|
209
|
+
}
|
|
210
|
+
|
|
211
|
+
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
|
212
|
+
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
|
|
213
|
+
cache_a[reg_ib].scales = buf_a[buf_ib].scales;
|
|
214
|
+
|
|
215
|
+
[[unroll]] for (uint iqs = 0; iqs < 2; iqs++) {
|
|
216
|
+
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
|
217
|
+
}
|
|
218
|
+
}
|
|
219
|
+
|
|
220
|
+
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
|
221
|
+
int32_t sum_d = 0;
|
|
222
|
+
int32_t sum_m = 0;
|
|
223
|
+
|
|
224
|
+
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
|
|
225
|
+
const uint8_t scale = cache_a[ib_a].scales[iqs / 4];
|
|
226
|
+
const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
|
|
227
|
+
const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);
|
|
228
|
+
|
|
229
|
+
sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
|
|
230
|
+
sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
|
|
231
|
+
}
|
|
232
|
+
|
|
233
|
+
return ACC_TYPE(float(cache_b.ds.x) * (float(cache_a[ib_a].dm.x) * float(sum_d) - float(cache_a[ib_a].dm.y) * float(sum_m)));
|
|
234
|
+
}
|
|
235
|
+
#endif
|
|
236
|
+
|
|
237
|
+
#if defined(DATA_A_Q3_K)
|
|
238
|
+
// 2-byte loads for Q3_K blocks (110 bytes)
|
|
239
|
+
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
|
240
|
+
const uint ib_k = ib / 8;
|
|
241
|
+
const uint hm_idx = iqs * QUANT_R_MMQ;
|
|
242
|
+
const uint iqs_k = (ib % 8) * 8 + hm_idx;
|
|
243
|
+
|
|
244
|
+
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
|
|
245
|
+
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
|
|
246
|
+
const uint hm_shift = iqs_k / 8;
|
|
247
|
+
|
|
248
|
+
// Repack 2x4 quants into one int
|
|
249
|
+
// Add the 3rd bit instead of subtracting it to allow packing the quants
|
|
250
|
+
// vec4 for unpack8 used due to #12147
|
|
251
|
+
const i8vec2 vals00 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303)))).xy |
|
|
252
|
+
unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;
|
|
253
|
+
const i8vec2 vals01 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303)))).xy |
|
|
254
|
+
unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;
|
|
255
|
+
const i8vec2 vals10 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303)))).xy |
|
|
256
|
+
unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;
|
|
257
|
+
const i8vec2 vals11 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303)))).xy |
|
|
258
|
+
unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;
|
|
259
|
+
buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) |
|
|
260
|
+
(pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4);
|
|
261
|
+
|
|
262
|
+
if (iqs == 0) {
|
|
263
|
+
const uint is = iqs_k / 4;
|
|
264
|
+
const i8vec2 scales = i8vec2(unpack8(uint32_t(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
|
|
265
|
+
(((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))).xy); // vec4 used due to #12147
|
|
266
|
+
|
|
267
|
+
buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
|
|
268
|
+
}
|
|
269
|
+
}
|
|
270
|
+
|
|
271
|
+
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
|
272
|
+
cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
|
|
273
|
+
|
|
274
|
+
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
|
|
275
|
+
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
|
276
|
+
}
|
|
277
|
+
}
|
|
278
|
+
|
|
279
|
+
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
|
280
|
+
float result = 0.0;
|
|
281
|
+
int32_t q_sum = 0;
|
|
282
|
+
|
|
283
|
+
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
|
|
284
|
+
// Subtract 4 from the quants to correct the 3rd bit offset
|
|
285
|
+
const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
|
|
286
|
+
|
|
287
|
+
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
|
|
288
|
+
}
|
|
289
|
+
result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
|
|
290
|
+
q_sum = 0;
|
|
291
|
+
|
|
292
|
+
[[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
|
|
293
|
+
const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
|
|
294
|
+
|
|
295
|
+
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
|
|
296
|
+
}
|
|
297
|
+
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
|
|
298
|
+
|
|
299
|
+
return ACC_TYPE(float(cache_b.ds.x) * result);
|
|
300
|
+
}
|
|
301
|
+
#endif
|
|
302
|
+
|
|
303
|
+
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
|
|
304
|
+
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
|
|
305
|
+
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
|
306
|
+
const uint ib_k = ib / 8;
|
|
307
|
+
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
|
|
308
|
+
|
|
309
|
+
const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
|
|
310
|
+
const uint qs_shift = ((iqs_k % 16) / 8) * 4;
|
|
311
|
+
|
|
312
|
+
// Repack 2x4 quants into one int
|
|
313
|
+
#if defined(DATA_A_Q4_K)
|
|
314
|
+
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F;
|
|
315
|
+
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
|
|
316
|
+
|
|
317
|
+
buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4);
|
|
318
|
+
#else // defined(DATA_A_Q5_K)
|
|
319
|
+
const uint qh_idx = iqs * QUANT_R_MMQ;
|
|
320
|
+
const uint qh_shift = iqs_k / 8;
|
|
321
|
+
|
|
322
|
+
buf_a[buf_ib].qs[iqs] = int32_t(((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x0F0F0F0F) |
|
|
323
|
+
(((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));
|
|
324
|
+
#endif
|
|
325
|
+
|
|
326
|
+
if (iqs == 0) {
|
|
327
|
+
// Scale index
|
|
328
|
+
const uint is = iqs_k / 8;
|
|
329
|
+
u8vec2 scale_dm;
|
|
330
|
+
if (is < 4) {
|
|
331
|
+
scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
|
|
332
|
+
} else {
|
|
333
|
+
scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
|
|
334
|
+
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
|
|
335
|
+
}
|
|
336
|
+
|
|
337
|
+
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
|
|
338
|
+
}
|
|
339
|
+
}
|
|
340
|
+
|
|
341
|
+
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
|
342
|
+
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
|
|
343
|
+
|
|
344
|
+
[[unroll]] for (uint iqs = 0; iqs < 8 / QUANT_R_MMQ; iqs++) {
|
|
345
|
+
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
|
346
|
+
}
|
|
347
|
+
}
|
|
348
|
+
|
|
349
|
+
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
|
350
|
+
int32_t q_sum = 0;
|
|
351
|
+
|
|
352
|
+
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
|
|
353
|
+
#if defined(DATA_A_Q4_K)
|
|
354
|
+
const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F);
|
|
355
|
+
#else // defined(DATA_A_Q5_K)
|
|
356
|
+
const int32_t qs_a = cache_a[ib_a].qs[iqs];
|
|
357
|
+
#endif
|
|
358
|
+
|
|
359
|
+
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
|
|
360
|
+
}
|
|
361
|
+
|
|
362
|
+
return ACC_TYPE(float(cache_b.ds.x) * float(cache_a[ib_a].dm.x) * float(q_sum) - float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
|
|
363
|
+
}
|
|
364
|
+
#endif
|
|
365
|
+
|
|
366
|
+
#if defined(DATA_A_Q6_K)
|
|
367
|
+
// 2-byte loads for Q6_K blocks (210 bytes)
|
|
368
|
+
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
|
369
|
+
const uint ib_k = ib / 8;
|
|
370
|
+
const uint iqs_k = (ib % 8) * 8 + iqs;
|
|
371
|
+
|
|
372
|
+
const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
|
|
373
|
+
const uint ql_shift = ((iqs_k % 32) / 16) * 4;
|
|
374
|
+
|
|
375
|
+
const uint qh_idx = (iqs_k / 32) * 8 + iqs;
|
|
376
|
+
const uint qh_shift = ((iqs_k % 32) / 8) * 2;
|
|
377
|
+
|
|
378
|
+
const i8vec2 vals00 = (unpack8(int32_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))).xy |
|
|
379
|
+
unpack8(int32_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4)).xy) - int8_t(32);
|
|
380
|
+
const i8vec2 vals01 = (unpack8(int32_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))).xy |
|
|
381
|
+
unpack8(int32_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4)).xy) - int8_t(32);
|
|
382
|
+
buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y));
|
|
383
|
+
|
|
384
|
+
if (iqs == 0) {
|
|
385
|
+
const uint is = iqs_k / 4;
|
|
386
|
+
const i8vec2 scales = unpack8(int32_t(data_a_packed16[ib_k].scales[is / 2])).xy;
|
|
387
|
+
|
|
388
|
+
buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
|
|
389
|
+
}
|
|
390
|
+
}
|
|
391
|
+
|
|
392
|
+
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
|
393
|
+
cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
|
|
394
|
+
|
|
395
|
+
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
|
|
396
|
+
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
|
397
|
+
}
|
|
398
|
+
}
|
|
399
|
+
|
|
400
|
+
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
|
401
|
+
float result = 0.0;
|
|
402
|
+
int32_t q_sum = 0;
|
|
403
|
+
|
|
404
|
+
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
|
|
405
|
+
const int32_t qs_a = cache_a[ib_a].qs[iqs];
|
|
406
|
+
|
|
407
|
+
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
|
|
408
|
+
}
|
|
409
|
+
result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
|
|
410
|
+
q_sum = 0;
|
|
411
|
+
|
|
412
|
+
[[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
|
|
413
|
+
const int32_t qs_a = cache_a[ib_a].qs[iqs];
|
|
414
|
+
|
|
415
|
+
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
|
|
416
|
+
}
|
|
417
|
+
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
|
|
418
|
+
|
|
419
|
+
return ACC_TYPE(float(cache_b.ds.x) * result);
|
|
420
|
+
}
|
|
421
|
+
#endif
|
|
422
|
+
|
|
423
|
+
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) {
|
|
424
|
+
if (is_in_bounds) {
|
|
425
|
+
const uint ib_outer = ib / 4;
|
|
426
|
+
const uint ib_inner = ib % 4;
|
|
427
|
+
|
|
428
|
+
if (iqs == 0) {
|
|
429
|
+
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
|
|
430
|
+
}
|
|
431
|
+
|
|
432
|
+
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
|
|
433
|
+
buf_b[buf_ib].qs[iqs * 4 ] = values.x;
|
|
434
|
+
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
|
|
435
|
+
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
|
|
436
|
+
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
|
|
437
|
+
} else {
|
|
438
|
+
if (iqs == 0) {
|
|
439
|
+
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f);
|
|
440
|
+
}
|
|
441
|
+
|
|
442
|
+
buf_b[buf_ib].qs[iqs * 4 ] = 0;
|
|
443
|
+
buf_b[buf_ib].qs[iqs * 4 + 1] = 0;
|
|
444
|
+
buf_b[buf_ib].qs[iqs * 4 + 2] = 0;
|
|
445
|
+
buf_b[buf_ib].qs[iqs * 4 + 3] = 0;
|
|
446
|
+
}
|
|
447
|
+
}
|
|
448
|
+
|
|
449
|
+
void block_b_to_registers(const uint ib) {
|
|
450
|
+
cache_b.ds = buf_b[ib].ds;
|
|
451
|
+
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
|
|
452
|
+
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
|
|
453
|
+
}
|
|
454
|
+
}
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
#if defined(DATA_A_Q4_0)
|
|
2
|
+
#define QUANT_R_MMQ 2
|
|
3
|
+
struct block_a_cache {
|
|
4
|
+
uint32_t qs[16/4];
|
|
5
|
+
FLOAT_TYPE dm;
|
|
6
|
+
};
|
|
7
|
+
#elif defined(DATA_A_Q4_1)
|
|
8
|
+
#define QUANT_R_MMQ 2
|
|
9
|
+
struct block_a_cache {
|
|
10
|
+
uint32_t qs[16/4];
|
|
11
|
+
FLOAT_TYPE_VEC2 dm;
|
|
12
|
+
};
|
|
13
|
+
#elif defined(DATA_A_Q5_0)
|
|
14
|
+
#define QUANT_R_MMQ 2
|
|
15
|
+
struct block_a_cache {
|
|
16
|
+
uint32_t qs[16/4];
|
|
17
|
+
uint32_t qh;
|
|
18
|
+
FLOAT_TYPE dm;
|
|
19
|
+
};
|
|
20
|
+
#elif defined(DATA_A_Q5_1)
|
|
21
|
+
#define QUANT_R_MMQ 2
|
|
22
|
+
struct block_a_cache {
|
|
23
|
+
uint32_t qs[16/4];
|
|
24
|
+
uint32_t qh;
|
|
25
|
+
FLOAT_TYPE_VEC2 dm;
|
|
26
|
+
};
|
|
27
|
+
#elif defined(DATA_A_Q8_0)
|
|
28
|
+
#define QUANT_R_MMQ 1
|
|
29
|
+
// AMD likes 4, Intel likes 1 and Nvidia likes 2
|
|
30
|
+
// #define BK_STEP 1
|
|
31
|
+
struct block_a_cache {
|
|
32
|
+
int32_t qs[32/4];
|
|
33
|
+
FLOAT_TYPE dm;
|
|
34
|
+
};
|
|
35
|
+
#elif defined(DATA_A_MXFP4)
|
|
36
|
+
#define QUANT_R_MMQ 2
|
|
37
|
+
struct block_a_cache {
|
|
38
|
+
int32_t qs[8];
|
|
39
|
+
FLOAT_TYPE d;
|
|
40
|
+
};
|
|
41
|
+
#elif defined(DATA_A_Q2_K)
|
|
42
|
+
#define QUANT_R_MMQ 4
|
|
43
|
+
struct block_a_cache {
|
|
44
|
+
uint32_t qs[2];
|
|
45
|
+
u8vec2 scales;
|
|
46
|
+
FLOAT_TYPE_VEC2 dm;
|
|
47
|
+
};
|
|
48
|
+
#elif defined(DATA_A_Q3_K)
|
|
49
|
+
#define QUANT_R_MMQ 2
|
|
50
|
+
struct block_a_cache {
|
|
51
|
+
uint32_t qs[4];
|
|
52
|
+
FLOAT_TYPE_VEC2 d_scales;
|
|
53
|
+
};
|
|
54
|
+
#elif defined(DATA_A_Q4_K)
|
|
55
|
+
#define QUANT_R_MMQ 2
|
|
56
|
+
struct block_a_cache {
|
|
57
|
+
uint32_t qs[4];
|
|
58
|
+
FLOAT_TYPE_VEC2 dm;
|
|
59
|
+
};
|
|
60
|
+
#elif defined(DATA_A_Q5_K)
|
|
61
|
+
#define QUANT_R_MMQ 1
|
|
62
|
+
struct block_a_cache {
|
|
63
|
+
int32_t qs[8];
|
|
64
|
+
FLOAT_TYPE_VEC2 dm;
|
|
65
|
+
};
|
|
66
|
+
#elif defined(DATA_A_Q6_K)
|
|
67
|
+
#define QUANT_R_MMQ 1
|
|
68
|
+
struct block_a_cache {
|
|
69
|
+
int32_t qs[8];
|
|
70
|
+
FLOAT_TYPE_VEC2 d_scales;
|
|
71
|
+
};
|
|
72
|
+
#endif
|
|
73
|
+
|
|
74
|
+
struct block_b_cache
|
|
75
|
+
{
|
|
76
|
+
int32_t qs[8];
|
|
77
|
+
FLOAT_TYPE_VEC2 ds;
|
|
78
|
+
};
|
|
@@ -8,9 +8,9 @@
|
|
|
8
8
|
#extension GL_KHR_shader_subgroup_basic : enable
|
|
9
9
|
#endif
|
|
10
10
|
|
|
11
|
-
#include "rte.
|
|
12
|
-
#include "types.
|
|
13
|
-
#include "utils.
|
|
11
|
+
#include "rte.glsl"
|
|
12
|
+
#include "types.glsl"
|
|
13
|
+
#include "utils.glsl"
|
|
14
14
|
|
|
15
15
|
layout (push_constant) uniform parameter2
|
|
16
16
|
{
|
|
@@ -23,16 +23,100 @@ layout (push_constant) uniform parameter2
|
|
|
23
23
|
uint rms_partials;
|
|
24
24
|
} p;
|
|
25
25
|
|
|
26
|
-
// Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
layout (binding =
|
|
30
|
-
layout (binding =
|
|
31
|
-
|
|
32
|
-
layout (binding =
|
|
26
|
+
// No readonly/writeonly decorations. Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
|
|
27
|
+
layout (binding = 0) buffer A0 {A_TYPE data_a[];} a0;
|
|
28
|
+
layout (binding = 1) buffer A1 {A_TYPE data_a[];} a1;
|
|
29
|
+
layout (binding = 2) buffer A2 {A_TYPE data_a[];} a2;
|
|
30
|
+
layout (binding = 3) buffer A3 {A_TYPE data_a[];} a3;
|
|
31
|
+
layout (binding = 4) buffer A4 {A_TYPE data_a[];} a4;
|
|
32
|
+
layout (binding = 5) buffer A5 {A_TYPE data_a[];} a5;
|
|
33
|
+
layout (binding = 6) buffer A6 {A_TYPE data_a[];} a6;
|
|
34
|
+
layout (binding = 7) buffer A7 {A_TYPE data_a[];} a7;
|
|
35
|
+
layout (binding = 8) buffer A8 {A_TYPE data_a[];} a8;
|
|
36
|
+
layout (binding = 9) buffer A9 {A_TYPE data_a[];} a9;
|
|
37
|
+
layout (binding = 10) buffer A10 {A_TYPE data_a[];} a10;
|
|
38
|
+
layout (binding = 11) buffer A11 {A_TYPE data_a[];} a11;
|
|
39
|
+
layout (binding = 0) buffer D0 {D_TYPE data_d[];} d0;
|
|
40
|
+
layout (binding = 1) buffer D1 {D_TYPE data_d[];} d1;
|
|
41
|
+
layout (binding = 2) buffer D2 {D_TYPE data_d[];} d2;
|
|
42
|
+
layout (binding = 3) buffer D3 {D_TYPE data_d[];} d3;
|
|
43
|
+
layout (binding = 4) buffer D4 {D_TYPE data_d[];} d4;
|
|
44
|
+
layout (binding = 5) buffer D5 {D_TYPE data_d[];} d5;
|
|
45
|
+
layout (binding = 6) buffer D6 {D_TYPE data_d[];} d6;
|
|
46
|
+
layout (binding = 7) buffer D7 {D_TYPE data_d[];} d7;
|
|
47
|
+
layout (binding = 8) buffer D8 {D_TYPE data_d[];} d8;
|
|
48
|
+
layout (binding = 9) buffer D9 {D_TYPE data_d[];} d9;
|
|
49
|
+
layout (binding = 10) buffer D10 {D_TYPE data_d[];} d10;
|
|
50
|
+
layout (binding = 11) buffer D11 {D_TYPE data_d[];} d11;
|
|
51
|
+
layout (binding = 0, std430) buffer PartialBuf0 {float partial_sums[];} partials0;
|
|
52
|
+
layout (binding = 1, std430) buffer PartialBuf1 {float partial_sums[];} partials1;
|
|
53
|
+
layout (binding = 2, std430) buffer PartialBuf2 {float partial_sums[];} partials2;
|
|
54
|
+
layout (binding = 3, std430) buffer PartialBuf3 {float partial_sums[];} partials3;
|
|
55
|
+
layout (binding = 4, std430) buffer PartialBuf4 {float partial_sums[];} partials4;
|
|
56
|
+
layout (binding = 5, std430) buffer PartialBuf5 {float partial_sums[];} partials5;
|
|
57
|
+
layout (binding = 6, std430) buffer PartialBuf6 {float partial_sums[];} partials6;
|
|
58
|
+
layout (binding = 7, std430) buffer PartialBuf7 {float partial_sums[];} partials7;
|
|
59
|
+
layout (binding = 8, std430) buffer PartialBuf8 {float partial_sums[];} partials8;
|
|
60
|
+
layout (binding = 9, std430) buffer PartialBuf9 {float partial_sums[];} partials9;
|
|
61
|
+
layout (binding = 10, std430) buffer PartialBuf10 {float partial_sums[];} partials10;
|
|
62
|
+
layout (binding = 11, std430) buffer PartialBuf11 {float partial_sums[];} partials11;
|
|
33
63
|
|
|
34
64
|
layout(constant_id = 0) const uint num_srcs = 2;
|
|
35
65
|
|
|
66
|
+
FLOAT_TYPE load_a(uint b, uint i) {
|
|
67
|
+
switch (b) {
|
|
68
|
+
case 0: return FLOAT_TYPE(a0.data_a[i]);
|
|
69
|
+
case 1: return FLOAT_TYPE(a1.data_a[i]);
|
|
70
|
+
case 2: return FLOAT_TYPE(a2.data_a[i]);
|
|
71
|
+
case 3: return FLOAT_TYPE(a3.data_a[i]);
|
|
72
|
+
case 4: return FLOAT_TYPE(a4.data_a[i]);
|
|
73
|
+
case 5: return FLOAT_TYPE(a5.data_a[i]);
|
|
74
|
+
case 6: return FLOAT_TYPE(a6.data_a[i]);
|
|
75
|
+
case 7: return FLOAT_TYPE(a7.data_a[i]);
|
|
76
|
+
case 8: return FLOAT_TYPE(a8.data_a[i]);
|
|
77
|
+
case 9: return FLOAT_TYPE(a9.data_a[i]);
|
|
78
|
+
case 10: return FLOAT_TYPE(a10.data_a[i]);
|
|
79
|
+
case 11: return FLOAT_TYPE(a11.data_a[i]);
|
|
80
|
+
default: return FLOAT_TYPE(0);
|
|
81
|
+
}
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
void store_d(uint b, uint i, FLOAT_TYPE v) {
|
|
85
|
+
switch (b) {
|
|
86
|
+
case 0: d0.data_d[i] = D_TYPE(v); break;
|
|
87
|
+
case 1: d1.data_d[i] = D_TYPE(v); break;
|
|
88
|
+
case 2: d2.data_d[i] = D_TYPE(v); break;
|
|
89
|
+
case 3: d3.data_d[i] = D_TYPE(v); break;
|
|
90
|
+
case 4: d4.data_d[i] = D_TYPE(v); break;
|
|
91
|
+
case 5: d5.data_d[i] = D_TYPE(v); break;
|
|
92
|
+
case 6: d6.data_d[i] = D_TYPE(v); break;
|
|
93
|
+
case 7: d7.data_d[i] = D_TYPE(v); break;
|
|
94
|
+
case 8: d8.data_d[i] = D_TYPE(v); break;
|
|
95
|
+
case 9: d9.data_d[i] = D_TYPE(v); break;
|
|
96
|
+
case 10: d10.data_d[i] = D_TYPE(v); break;
|
|
97
|
+
case 11: d11.data_d[i] = D_TYPE(v); break;
|
|
98
|
+
default: break;
|
|
99
|
+
}
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
void store_partial(uint b, uint i, float v) {
|
|
103
|
+
switch (b) {
|
|
104
|
+
case 0: partials0.partial_sums[i] = v; break;
|
|
105
|
+
case 1: partials1.partial_sums[i] = v; break;
|
|
106
|
+
case 2: partials2.partial_sums[i] = v; break;
|
|
107
|
+
case 3: partials3.partial_sums[i] = v; break;
|
|
108
|
+
case 4: partials4.partial_sums[i] = v; break;
|
|
109
|
+
case 5: partials5.partial_sums[i] = v; break;
|
|
110
|
+
case 6: partials6.partial_sums[i] = v; break;
|
|
111
|
+
case 7: partials7.partial_sums[i] = v; break;
|
|
112
|
+
case 8: partials8.partial_sums[i] = v; break;
|
|
113
|
+
case 9: partials9.partial_sums[i] = v; break;
|
|
114
|
+
case 10: partials10.partial_sums[i] = v; break;
|
|
115
|
+
case 11: partials11.partial_sums[i] = v; break;
|
|
116
|
+
default: break;
|
|
117
|
+
}
|
|
118
|
+
}
|
|
119
|
+
|
|
36
120
|
uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
|
|
37
121
|
return i03*p.nb[s][3] + i02*p.nb[s][2] + i01*p.nb[s][1] + i00*p.nb[s][0];
|
|
38
122
|
}
|
|
@@ -78,10 +162,10 @@ void main() {
|
|
|
78
162
|
|
|
79
163
|
FLOAT_TYPE sum = FLOAT_TYPE(0);
|
|
80
164
|
[[unroll]] for (uint s = 0; s < num_srcs; ++s) {
|
|
81
|
-
sum +=
|
|
165
|
+
sum += load_a(s, src_idx(s, i00, i01, i02, i03));
|
|
82
166
|
}
|
|
83
167
|
sum_sq += sum*sum;
|
|
84
|
-
|
|
168
|
+
store_d(num_srcs, dst_idx(i00, i01, i02, i03), sum);
|
|
85
169
|
|
|
86
170
|
idx += num_threads;
|
|
87
171
|
}
|
|
@@ -104,7 +188,7 @@ void main() {
|
|
|
104
188
|
}
|
|
105
189
|
|
|
106
190
|
if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
|
|
107
|
-
|
|
191
|
+
store_partial(num_srcs + 1, orig_idx / (num_iter * num_threads), sum_sq);
|
|
108
192
|
}
|
|
109
193
|
}
|
|
110
194
|
#endif
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
#version 450
|
|
2
|
+
|
|
3
|
+
#include "generic_head.glsl"
|
|
4
|
+
#include "types.glsl"
|
|
5
|
+
|
|
6
|
+
#extension GL_EXT_control_flow_attributes : enable
|
|
7
|
+
|
|
8
|
+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
|
9
|
+
|
|
10
|
+
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
|
11
|
+
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
|
12
|
+
|
|
13
|
+
void main() {
|
|
14
|
+
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
|
15
|
+
|
|
16
|
+
if (i >= p.KX) {
|
|
17
|
+
return;
|
|
18
|
+
}
|
|
19
|
+
data_d[i] = D_TYPE(-float(data_a[i]));
|
|
20
|
+
}
|