whispercpp 1.3.4 → 1.3.5
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +47 -23
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/cli/cli.cpp +121 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +10 -11
- data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
- data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
- data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
- data/ext/sources/examples/talk-llama/llama-context.h +57 -9
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
- data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
- data/ext/sources/examples/talk-llama/llama-model.h +44 -3
- data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
- data/ext/sources/examples/talk-llama/llama.cpp +729 -2
- data/ext/sources/examples/talk-llama/llama.h +152 -14
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +82 -54
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +4 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +190 -12
- data/ext/sources/ggml/src/CMakeLists.txt +82 -11
- data/ext/sources/ggml/src/ggml-alloc.c +124 -41
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
- data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
- data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
- data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
- data/ext/sources/ggml/src/ggml-impl.h +67 -6
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +425 -33
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +101 -35
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +119 -2
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +7 -0
- data/whispercpp.gemspec +1 -1
- metadata +287 -34
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -69,6 +69,10 @@
|
|
|
69
69
|
#define VECTOR_REGISTERS 16
|
|
70
70
|
#endif
|
|
71
71
|
|
|
72
|
+
#if defined(__riscv_v_intrinsic)
|
|
73
|
+
#define LMUL 4
|
|
74
|
+
#endif
|
|
75
|
+
|
|
72
76
|
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
|
|
73
77
|
|
|
74
78
|
namespace {
|
|
@@ -117,8 +121,7 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
|
|
|
117
121
|
#endif
|
|
118
122
|
|
|
119
123
|
#if defined(__MMA__)
|
|
120
|
-
|
|
121
|
-
typedef __vector_quad acc_t;
|
|
124
|
+
#include "sgemm-ppc.h"
|
|
122
125
|
#endif
|
|
123
126
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
124
127
|
// VECTORIZED FUSED MULTIPLY ADD
|
|
@@ -176,6 +179,46 @@ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
|
|
|
176
179
|
}
|
|
177
180
|
#endif
|
|
178
181
|
|
|
182
|
+
#if defined(__riscv_zvfh)
|
|
183
|
+
template <>
|
|
184
|
+
inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
|
|
185
|
+
return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
|
|
186
|
+
}
|
|
187
|
+
inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
|
|
188
|
+
return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
|
|
189
|
+
}
|
|
190
|
+
inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
|
|
191
|
+
return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
|
|
192
|
+
}
|
|
193
|
+
inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
|
|
194
|
+
return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
|
|
195
|
+
}
|
|
196
|
+
inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
|
|
197
|
+
return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
|
|
198
|
+
}
|
|
199
|
+
inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
|
|
200
|
+
return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
|
|
201
|
+
}
|
|
202
|
+
inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
|
|
203
|
+
return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
|
|
204
|
+
}
|
|
205
|
+
inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
|
|
206
|
+
return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
|
|
207
|
+
}
|
|
208
|
+
#endif
|
|
209
|
+
|
|
210
|
+
#if defined(__riscv_zvfbfwma)
|
|
211
|
+
inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
|
|
212
|
+
return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
|
|
213
|
+
}
|
|
214
|
+
inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
|
|
215
|
+
return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
|
|
216
|
+
}
|
|
217
|
+
inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
|
|
218
|
+
return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
|
|
219
|
+
}
|
|
220
|
+
#endif
|
|
221
|
+
|
|
179
222
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
180
223
|
// VECTORIZED HORIZONTAL SUM
|
|
181
224
|
|
|
@@ -228,6 +271,25 @@ inline float hsum(__m512 x) {
|
|
|
228
271
|
}
|
|
229
272
|
#endif // __AVX512F__
|
|
230
273
|
|
|
274
|
+
#if defined(__riscv_zvfh)
|
|
275
|
+
inline float hsum(vfloat32m1_t x) {
|
|
276
|
+
return __riscv_vfmv_f_s_f32m1_f32(
|
|
277
|
+
__riscv_vfredusum_vs_f32m1_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m1()));
|
|
278
|
+
}
|
|
279
|
+
inline float hsum(vfloat32m2_t x) {
|
|
280
|
+
return __riscv_vfmv_f_s_f32m1_f32(
|
|
281
|
+
__riscv_vfredusum_vs_f32m2_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m2()));
|
|
282
|
+
}
|
|
283
|
+
inline float hsum(vfloat32m4_t x) {
|
|
284
|
+
return __riscv_vfmv_f_s_f32m1_f32(
|
|
285
|
+
__riscv_vfredusum_vs_f32m4_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m4()));
|
|
286
|
+
}
|
|
287
|
+
inline float hsum(vfloat32m8_t x) {
|
|
288
|
+
return __riscv_vfmv_f_s_f32m1_f32(
|
|
289
|
+
__riscv_vfredusum_vs_f32m8_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m8()));
|
|
290
|
+
}
|
|
291
|
+
#endif
|
|
292
|
+
|
|
231
293
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
232
294
|
// VECTORIZED MEMORY LOADING
|
|
233
295
|
|
|
@@ -316,6 +378,88 @@ template <> inline __m256bh load(const float *p) {
|
|
|
316
378
|
}
|
|
317
379
|
#endif
|
|
318
380
|
|
|
381
|
+
#if defined(__riscv_zvfh)
|
|
382
|
+
template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) {
|
|
383
|
+
return __riscv_vle16_v_f16mf2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16mf2());
|
|
384
|
+
}
|
|
385
|
+
template <> inline vfloat16m1_t load(const ggml_fp16_t *p) {
|
|
386
|
+
return __riscv_vle16_v_f16m1(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m1());
|
|
387
|
+
}
|
|
388
|
+
template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
|
|
389
|
+
return __riscv_vle16_v_f16m2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m2());
|
|
390
|
+
}
|
|
391
|
+
template <> inline vfloat16m4_t load(const ggml_fp16_t *p) {
|
|
392
|
+
return __riscv_vle16_v_f16m4(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m4());
|
|
393
|
+
}
|
|
394
|
+
template <> inline vfloat32m1_t load(const float *p) {
|
|
395
|
+
return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1());
|
|
396
|
+
}
|
|
397
|
+
template <> inline vfloat32m2_t load(const float *p) {
|
|
398
|
+
return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2());
|
|
399
|
+
}
|
|
400
|
+
template <> inline vfloat32m4_t load(const float *p) {
|
|
401
|
+
return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4());
|
|
402
|
+
}
|
|
403
|
+
template <> inline vfloat32m8_t load(const float *p) {
|
|
404
|
+
return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8());
|
|
405
|
+
}
|
|
406
|
+
#endif
|
|
407
|
+
|
|
408
|
+
#if defined(__riscv_zvfbfwma)
|
|
409
|
+
template <> inline vbfloat16mf2_t load(const ggml_bf16_t *p) {
|
|
410
|
+
return __riscv_vle16_v_bf16mf2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16mf2());
|
|
411
|
+
}
|
|
412
|
+
template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) {
|
|
413
|
+
return __riscv_vle16_v_bf16m1(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m1());
|
|
414
|
+
}
|
|
415
|
+
template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) {
|
|
416
|
+
return __riscv_vle16_v_bf16m2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m2());
|
|
417
|
+
}
|
|
418
|
+
#endif
|
|
419
|
+
|
|
420
|
+
#if defined(__riscv_zvfh)
|
|
421
|
+
template <typename T> T set_zero();
|
|
422
|
+
|
|
423
|
+
template <> inline vfloat16mf2_t set_zero() {
|
|
424
|
+
return __riscv_vfmv_v_f_f16mf2(0, __riscv_vsetvlmax_e16mf2());
|
|
425
|
+
}
|
|
426
|
+
template <> inline vfloat16m1_t set_zero() {
|
|
427
|
+
return __riscv_vfmv_v_f_f16m1(0, __riscv_vsetvlmax_e16m1());
|
|
428
|
+
}
|
|
429
|
+
template <> inline vfloat16m2_t set_zero() {
|
|
430
|
+
return __riscv_vfmv_v_f_f16m2(0, __riscv_vsetvlmax_e16m2());
|
|
431
|
+
}
|
|
432
|
+
template <> inline vfloat16m4_t set_zero() {
|
|
433
|
+
return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4());
|
|
434
|
+
}
|
|
435
|
+
template <> inline vfloat32m1_t set_zero() {
|
|
436
|
+
return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1());
|
|
437
|
+
}
|
|
438
|
+
template <> inline vfloat32m2_t set_zero() {
|
|
439
|
+
return __riscv_vfmv_v_f_f32m2(0, __riscv_vsetvlmax_e32m2());
|
|
440
|
+
}
|
|
441
|
+
template <> inline vfloat32m4_t set_zero() {
|
|
442
|
+
return __riscv_vfmv_v_f_f32m4(0, __riscv_vsetvlmax_e32m4());
|
|
443
|
+
}
|
|
444
|
+
template <> inline vfloat32m8_t set_zero() {
|
|
445
|
+
return __riscv_vfmv_v_f_f32m8(0, __riscv_vsetvlmax_e32m8());
|
|
446
|
+
}
|
|
447
|
+
#endif
|
|
448
|
+
|
|
449
|
+
#if defined(__riscv_v_intrinsic)
|
|
450
|
+
template <typename T> size_t vlmax() {
|
|
451
|
+
if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
|
|
452
|
+
else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
|
|
453
|
+
else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
|
|
454
|
+
else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
|
|
455
|
+
else if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); }
|
|
456
|
+
else if constexpr (std::is_same_v<T, vfloat32m2_t>) { return __riscv_vsetvlmax_e32m2(); }
|
|
457
|
+
else if constexpr (std::is_same_v<T, vfloat32m4_t>) { return __riscv_vsetvlmax_e32m4(); }
|
|
458
|
+
else if constexpr (std::is_same_v<T, vfloat32m8_t>) { return __riscv_vsetvlmax_e32m8(); }
|
|
459
|
+
return 0;
|
|
460
|
+
}
|
|
461
|
+
#endif
|
|
462
|
+
|
|
319
463
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
320
464
|
// FLOATING POINT MATRIX MULTIPLICATION
|
|
321
465
|
|
|
@@ -489,6 +633,573 @@ class tinyBLAS {
|
|
|
489
633
|
const int64_t ldc;
|
|
490
634
|
};
|
|
491
635
|
|
|
636
|
+
#if defined(__riscv_v_intrinsic)
|
|
637
|
+
template <typename D, typename V, typename TA, typename TB, typename TC>
|
|
638
|
+
class tinyBLAS_RVV {
|
|
639
|
+
public:
|
|
640
|
+
tinyBLAS_RVV(const ggml_compute_params * params, int64_t k,
|
|
641
|
+
const TA *A, int64_t lda,
|
|
642
|
+
const TB *B, int64_t ldb,
|
|
643
|
+
TC *C, int64_t ldc)
|
|
644
|
+
: params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
|
|
645
|
+
}
|
|
646
|
+
|
|
647
|
+
bool matmul(int64_t m, int64_t n) {
|
|
648
|
+
if (k % vlmax<V>() != 0) {
|
|
649
|
+
return false;
|
|
650
|
+
}
|
|
651
|
+
|
|
652
|
+
#if LMUL == 1
|
|
653
|
+
if (m % 16 == 0 && (m/16 >= params->nth)) {
|
|
654
|
+
const int64_t SIZE_N = BLOCK_SIZE<6>(n);
|
|
655
|
+
mnpack<4, 6, 4>(m, n, SIZE_N, 12);
|
|
656
|
+
return true;
|
|
657
|
+
}
|
|
658
|
+
if (m % 8 == 0 ) {
|
|
659
|
+
const int64_t SIZE_N = BLOCK_SIZE<6>(n);
|
|
660
|
+
mnpack<4, 6, 2>(m, n, SIZE_N, 12);
|
|
661
|
+
return true;
|
|
662
|
+
}
|
|
663
|
+
if (m % 4 == 0) {
|
|
664
|
+
const int64_t SIZE_N = BLOCK_SIZE<6>(n);
|
|
665
|
+
mnpack<4, 6, 1>(m, n, SIZE_N, 12);
|
|
666
|
+
return true;
|
|
667
|
+
}
|
|
668
|
+
#elif LMUL == 2
|
|
669
|
+
if (m % 16 == 0 && (m/16 >= params->nth)) {
|
|
670
|
+
const int64_t SIZE_N = BLOCK_SIZE<3>(n);
|
|
671
|
+
mnpack<4, 3, 4>(m, n, SIZE_N, 24);
|
|
672
|
+
return true;
|
|
673
|
+
}
|
|
674
|
+
if (m % 8 == 0 ) {
|
|
675
|
+
const int64_t SIZE_N = BLOCK_SIZE<3>(n);
|
|
676
|
+
mnpack<4, 3, 2>(m, n, SIZE_N, 24);
|
|
677
|
+
return true;
|
|
678
|
+
}
|
|
679
|
+
if (m % 4 == 0) {
|
|
680
|
+
const int64_t SIZE_N = BLOCK_SIZE<3>(n);
|
|
681
|
+
mnpack<4, 3, 1>(m, n, SIZE_N, 24);
|
|
682
|
+
return true;
|
|
683
|
+
}
|
|
684
|
+
#else // LMUL = 4
|
|
685
|
+
if (m % 16 == 0 && (m/16 >= params->nth)) {
|
|
686
|
+
const int64_t SIZE_N = BLOCK_SIZE<2>(n);
|
|
687
|
+
mnpack<2, 2, 8>(m, n, SIZE_N, 36);
|
|
688
|
+
return true;
|
|
689
|
+
}
|
|
690
|
+
if (m % 8 == 0 ) {
|
|
691
|
+
const int64_t SIZE_N = BLOCK_SIZE<2>(n);
|
|
692
|
+
mnpack<2, 2, 4>(m, n, SIZE_N, 36);
|
|
693
|
+
return true;
|
|
694
|
+
}
|
|
695
|
+
if (m % 4 == 0) {
|
|
696
|
+
const int64_t SIZE_N = BLOCK_SIZE<2>(n);
|
|
697
|
+
mnpack<2, 2, 2>(m, n, SIZE_N, 36);
|
|
698
|
+
return true;
|
|
699
|
+
}
|
|
700
|
+
#endif
|
|
701
|
+
return false;
|
|
702
|
+
}
|
|
703
|
+
|
|
704
|
+
private:
|
|
705
|
+
template<int RM, int RN, int BM>
|
|
706
|
+
inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
|
|
707
|
+
if (SIZE_N == RN) {
|
|
708
|
+
return gemm<RM, RN, BM>(m, n, BN);
|
|
709
|
+
}
|
|
710
|
+
if constexpr (RN > 1) {
|
|
711
|
+
return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
|
|
712
|
+
} else {
|
|
713
|
+
GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
|
|
714
|
+
GGML_ASSERT(false); // we have miss something.
|
|
715
|
+
}
|
|
716
|
+
}
|
|
717
|
+
|
|
718
|
+
inline void gemm_bloc_4x6(int64_t ii, int64_t jj) {
|
|
719
|
+
size_t vl = vlmax<V>();
|
|
720
|
+
D Cv00 = set_zero<D>();
|
|
721
|
+
D Cv01 = set_zero<D>();
|
|
722
|
+
D Cv02 = set_zero<D>();
|
|
723
|
+
D Cv03 = set_zero<D>();
|
|
724
|
+
D Cv10 = set_zero<D>();
|
|
725
|
+
D Cv11 = set_zero<D>();
|
|
726
|
+
D Cv12 = set_zero<D>();
|
|
727
|
+
D Cv13 = set_zero<D>();
|
|
728
|
+
D Cv20 = set_zero<D>();
|
|
729
|
+
D Cv21 = set_zero<D>();
|
|
730
|
+
D Cv22 = set_zero<D>();
|
|
731
|
+
D Cv23 = set_zero<D>();
|
|
732
|
+
D Cv30 = set_zero<D>();
|
|
733
|
+
D Cv31 = set_zero<D>();
|
|
734
|
+
D Cv32 = set_zero<D>();
|
|
735
|
+
D Cv33 = set_zero<D>();
|
|
736
|
+
D Cv40 = set_zero<D>();
|
|
737
|
+
D Cv41 = set_zero<D>();
|
|
738
|
+
D Cv42 = set_zero<D>();
|
|
739
|
+
D Cv43 = set_zero<D>();
|
|
740
|
+
D Cv50 = set_zero<D>();
|
|
741
|
+
D Cv51 = set_zero<D>();
|
|
742
|
+
D Cv52 = set_zero<D>();
|
|
743
|
+
D Cv53 = set_zero<D>();
|
|
744
|
+
|
|
745
|
+
for (int64_t l = 0; l < k; l += vl) {
|
|
746
|
+
V Bv0 = load<V>(B + ldb * (jj + 0) + l);
|
|
747
|
+
V Bv1 = load<V>(B + ldb * (jj + 1) + l);
|
|
748
|
+
V Bv2 = load<V>(B + ldb * (jj + 2) + l);
|
|
749
|
+
V Bv3 = load<V>(B + ldb * (jj + 3) + l);
|
|
750
|
+
V Bv4 = load<V>(B + ldb * (jj + 4) + l);
|
|
751
|
+
V Bv5 = load<V>(B + ldb * (jj + 5) + l);
|
|
752
|
+
|
|
753
|
+
V Av0 = load<V>(A + lda * (ii + 0) + l);
|
|
754
|
+
Cv00 = madd(Av0, Bv0, Cv00);
|
|
755
|
+
Cv10 = madd(Av0, Bv1, Cv10);
|
|
756
|
+
Cv20 = madd(Av0, Bv2, Cv20);
|
|
757
|
+
Cv30 = madd(Av0, Bv3, Cv30);
|
|
758
|
+
Cv40 = madd(Av0, Bv4, Cv40);
|
|
759
|
+
Cv50 = madd(Av0, Bv5, Cv50);
|
|
760
|
+
|
|
761
|
+
V Av1 = load<V>(A + lda * (ii + 1) + l);
|
|
762
|
+
Cv01 = madd(Av1, Bv0, Cv01);
|
|
763
|
+
Cv11 = madd(Av1, Bv1, Cv11);
|
|
764
|
+
Cv21 = madd(Av1, Bv2, Cv21);
|
|
765
|
+
Cv31 = madd(Av1, Bv3, Cv31);
|
|
766
|
+
Cv41 = madd(Av1, Bv4, Cv41);
|
|
767
|
+
Cv51 = madd(Av1, Bv5, Cv51);
|
|
768
|
+
|
|
769
|
+
V Av2 = load<V>(A + lda * (ii + 2) + l);
|
|
770
|
+
Cv02 = madd(Av2, Bv0, Cv02);
|
|
771
|
+
Cv12 = madd(Av2, Bv1, Cv12);
|
|
772
|
+
Cv22 = madd(Av2, Bv2, Cv22);
|
|
773
|
+
Cv32 = madd(Av2, Bv3, Cv32);
|
|
774
|
+
Cv42 = madd(Av2, Bv4, Cv42);
|
|
775
|
+
Cv52 = madd(Av2, Bv5, Cv52);
|
|
776
|
+
|
|
777
|
+
V Av3 = load<V>(A + lda * (ii + 3) + l);
|
|
778
|
+
Cv03 = madd(Av3, Bv0, Cv03);
|
|
779
|
+
Cv13 = madd(Av3, Bv1, Cv13);
|
|
780
|
+
Cv23 = madd(Av3, Bv2, Cv23);
|
|
781
|
+
Cv33 = madd(Av3, Bv3, Cv33);
|
|
782
|
+
Cv43 = madd(Av3, Bv4, Cv43);
|
|
783
|
+
Cv53 = madd(Av3, Bv5, Cv53);
|
|
784
|
+
}
|
|
785
|
+
|
|
786
|
+
C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
|
|
787
|
+
C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
|
|
788
|
+
C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
|
|
789
|
+
C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
|
|
790
|
+
C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
|
|
791
|
+
C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
|
|
792
|
+
C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
|
|
793
|
+
C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
|
|
794
|
+
C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
|
|
795
|
+
C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
|
|
796
|
+
C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
|
|
797
|
+
C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
|
|
798
|
+
C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
|
|
799
|
+
C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
|
|
800
|
+
C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
|
|
801
|
+
C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
|
|
802
|
+
C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
|
|
803
|
+
C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
|
|
804
|
+
C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
|
|
805
|
+
C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
|
|
806
|
+
C[ldc * (jj + 5) + (ii + 0)] = hsum(Cv50);
|
|
807
|
+
C[ldc * (jj + 5) + (ii + 1)] = hsum(Cv51);
|
|
808
|
+
C[ldc * (jj + 5) + (ii + 2)] = hsum(Cv52);
|
|
809
|
+
C[ldc * (jj + 5) + (ii + 3)] = hsum(Cv53);
|
|
810
|
+
}
|
|
811
|
+
|
|
812
|
+
inline void gemm_bloc_4x5(int64_t ii, int64_t jj) {
|
|
813
|
+
size_t vl = vlmax<V>();
|
|
814
|
+
D Cv00 = set_zero<D>();
|
|
815
|
+
D Cv01 = set_zero<D>();
|
|
816
|
+
D Cv02 = set_zero<D>();
|
|
817
|
+
D Cv03 = set_zero<D>();
|
|
818
|
+
D Cv10 = set_zero<D>();
|
|
819
|
+
D Cv11 = set_zero<D>();
|
|
820
|
+
D Cv12 = set_zero<D>();
|
|
821
|
+
D Cv13 = set_zero<D>();
|
|
822
|
+
D Cv20 = set_zero<D>();
|
|
823
|
+
D Cv21 = set_zero<D>();
|
|
824
|
+
D Cv22 = set_zero<D>();
|
|
825
|
+
D Cv23 = set_zero<D>();
|
|
826
|
+
D Cv30 = set_zero<D>();
|
|
827
|
+
D Cv31 = set_zero<D>();
|
|
828
|
+
D Cv32 = set_zero<D>();
|
|
829
|
+
D Cv33 = set_zero<D>();
|
|
830
|
+
D Cv40 = set_zero<D>();
|
|
831
|
+
D Cv41 = set_zero<D>();
|
|
832
|
+
D Cv42 = set_zero<D>();
|
|
833
|
+
D Cv43 = set_zero<D>();
|
|
834
|
+
|
|
835
|
+
for (int64_t l = 0; l < k; l += vl) {
|
|
836
|
+
V Bv0 = load<V>(B + ldb * (jj + 0) + l);
|
|
837
|
+
V Bv1 = load<V>(B + ldb * (jj + 1) + l);
|
|
838
|
+
V Bv2 = load<V>(B + ldb * (jj + 2) + l);
|
|
839
|
+
V Bv3 = load<V>(B + ldb * (jj + 3) + l);
|
|
840
|
+
V Bv4 = load<V>(B + ldb * (jj + 4) + l);
|
|
841
|
+
|
|
842
|
+
V Av0 = load<V>(A + lda * (ii + 0) + l);
|
|
843
|
+
Cv00 = madd(Av0, Bv0, Cv00);
|
|
844
|
+
Cv10 = madd(Av0, Bv1, Cv10);
|
|
845
|
+
Cv20 = madd(Av0, Bv2, Cv20);
|
|
846
|
+
Cv30 = madd(Av0, Bv3, Cv30);
|
|
847
|
+
Cv40 = madd(Av0, Bv4, Cv40);
|
|
848
|
+
|
|
849
|
+
V Av1 = load<V>(A + lda * (ii + 1) + l);
|
|
850
|
+
Cv01 = madd(Av1, Bv0, Cv01);
|
|
851
|
+
Cv11 = madd(Av1, Bv1, Cv11);
|
|
852
|
+
Cv21 = madd(Av1, Bv2, Cv21);
|
|
853
|
+
Cv31 = madd(Av1, Bv3, Cv31);
|
|
854
|
+
Cv41 = madd(Av1, Bv4, Cv41);
|
|
855
|
+
|
|
856
|
+
V Av2 = load<V>(A + lda * (ii + 2) + l);
|
|
857
|
+
Cv02 = madd(Av2, Bv0, Cv02);
|
|
858
|
+
Cv12 = madd(Av2, Bv1, Cv12);
|
|
859
|
+
Cv22 = madd(Av2, Bv2, Cv22);
|
|
860
|
+
Cv32 = madd(Av2, Bv3, Cv32);
|
|
861
|
+
Cv42 = madd(Av2, Bv4, Cv42);
|
|
862
|
+
|
|
863
|
+
V Av3 = load<V>(A + lda * (ii + 3) + l);
|
|
864
|
+
Cv03 = madd(Av3, Bv0, Cv03);
|
|
865
|
+
Cv13 = madd(Av3, Bv1, Cv13);
|
|
866
|
+
Cv23 = madd(Av3, Bv2, Cv23);
|
|
867
|
+
Cv33 = madd(Av3, Bv3, Cv33);
|
|
868
|
+
Cv43 = madd(Av3, Bv4, Cv43);
|
|
869
|
+
}
|
|
870
|
+
|
|
871
|
+
C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
|
|
872
|
+
C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
|
|
873
|
+
C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
|
|
874
|
+
C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
|
|
875
|
+
C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
|
|
876
|
+
C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
|
|
877
|
+
C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
|
|
878
|
+
C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
|
|
879
|
+
C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
|
|
880
|
+
C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
|
|
881
|
+
C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
|
|
882
|
+
C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
|
|
883
|
+
C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
|
|
884
|
+
C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
|
|
885
|
+
C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
|
|
886
|
+
C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
|
|
887
|
+
C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
|
|
888
|
+
C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
|
|
889
|
+
C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
|
|
890
|
+
C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
|
|
891
|
+
}
|
|
892
|
+
|
|
893
|
+
// Computes one 4x4 tile of the output:
//   C[ldc*(jj+c) + (ii+r)] = sum over l < k of A[lda*(ii+r) + l] * B[ldb*(jj+c) + l],  r, c in 0..3
// i.e. each output element is the dot product of row ii+r of A with row jj+c of B.
// The k-loop loads a full vlmax<V>() vector per step, so k is assumed to be a multiple of
// vlmax<V>() (NOTE(review): presumably enforced by the caller before dispatching here).
// Accumulators are named rRcC (A row ii+R, B row jj+C) and spelled out one by one,
// presumably because RVV vector types are sizeless and cannot be array elements.
inline void gemm_bloc_4x4(int64_t ii, int64_t jj) {
    const size_t vl = vlmax<V>();

    // Row base pointers, hoisted out of the k-loop.
    const TA *a_r0 = A + lda * (ii + 0);
    const TA *a_r1 = A + lda * (ii + 1);
    const TA *a_r2 = A + lda * (ii + 2);
    const TA *a_r3 = A + lda * (ii + 3);
    const TB *b_c0 = B + ldb * (jj + 0);
    const TB *b_c1 = B + ldb * (jj + 1);
    const TB *b_c2 = B + ldb * (jj + 2);
    const TB *b_c3 = B + ldb * (jj + 3);

    D r0c0 = set_zero<D>(), r1c0 = set_zero<D>(), r2c0 = set_zero<D>(), r3c0 = set_zero<D>();
    D r0c1 = set_zero<D>(), r1c1 = set_zero<D>(), r2c1 = set_zero<D>(), r3c1 = set_zero<D>();
    D r0c2 = set_zero<D>(), r1c2 = set_zero<D>(), r2c2 = set_zero<D>(), r3c2 = set_zero<D>();
    D r0c3 = set_zero<D>(), r1c3 = set_zero<D>(), r2c3 = set_zero<D>(), r3c3 = set_zero<D>();

    for (int64_t l = 0; l < k; l += vl) {
        V a0 = load<V>(a_r0 + l);
        V a1 = load<V>(a_r1 + l);
        V a2 = load<V>(a_r2 + l);
        V a3 = load<V>(a_r3 + l);

        // One B vector at a time, reused across the four A rows.
        V b = load<V>(b_c0 + l);
        r0c0 = madd(a0, b, r0c0);
        r1c0 = madd(a1, b, r1c0);
        r2c0 = madd(a2, b, r2c0);
        r3c0 = madd(a3, b, r3c0);

        b = load<V>(b_c1 + l);
        r0c1 = madd(a0, b, r0c1);
        r1c1 = madd(a1, b, r1c1);
        r2c1 = madd(a2, b, r2c1);
        r3c1 = madd(a3, b, r3c1);

        b = load<V>(b_c2 + l);
        r0c2 = madd(a0, b, r0c2);
        r1c2 = madd(a1, b, r1c2);
        r2c2 = madd(a2, b, r2c2);
        r3c2 = madd(a3, b, r3c2);

        b = load<V>(b_c3 + l);
        r0c3 = madd(a0, b, r0c3);
        r1c3 = madd(a1, b, r1c3);
        r2c3 = madd(a2, b, r2c3);
        r3c3 = madd(a3, b, r3c3);
    }

    // C is indexed as C[ldc*column + row]; out_cC points at (row ii, column jj+C).
    TC *out_c0 = C + ldc * (jj + 0) + ii;
    TC *out_c1 = C + ldc * (jj + 1) + ii;
    TC *out_c2 = C + ldc * (jj + 2) + ii;
    TC *out_c3 = C + ldc * (jj + 3) + ii;
    out_c0[0] = hsum(r0c0); out_c0[1] = hsum(r1c0); out_c0[2] = hsum(r2c0); out_c0[3] = hsum(r3c0);
    out_c1[0] = hsum(r0c1); out_c1[1] = hsum(r1c1); out_c1[2] = hsum(r2c1); out_c1[3] = hsum(r3c1);
    out_c2[0] = hsum(r0c2); out_c2[1] = hsum(r1c2); out_c2[2] = hsum(r2c2); out_c2[3] = hsum(r3c2);
    out_c3[0] = hsum(r0c3); out_c3[1] = hsum(r1c3); out_c3[2] = hsum(r2c3); out_c3[3] = hsum(r3c3);
}
|
|
960
|
+
|
|
961
|
+
// Computes one 4x3 tile of the output:
//   C[ldc*(jj+c) + (ii+r)] = sum over l < k of A[lda*(ii+r) + l] * B[ldb*(jj+c) + l],  r in 0..3, c in 0..2
// The k-loop loads a full vlmax<V>() vector per step, so k is assumed to be a multiple of
// vlmax<V>() (NOTE(review): presumably enforced by the caller).
// Accumulators are named rRcC (A row ii+R, B row jj+C).
inline void gemm_bloc_4x3(int64_t ii, int64_t jj) {
    const size_t vl = vlmax<V>();

    // Row base pointers, hoisted out of the k-loop.
    const TA *a_r0 = A + lda * (ii + 0);
    const TA *a_r1 = A + lda * (ii + 1);
    const TA *a_r2 = A + lda * (ii + 2);
    const TA *a_r3 = A + lda * (ii + 3);
    const TB *b_c0 = B + ldb * (jj + 0);
    const TB *b_c1 = B + ldb * (jj + 1);
    const TB *b_c2 = B + ldb * (jj + 2);

    D r0c0 = set_zero<D>(), r1c0 = set_zero<D>(), r2c0 = set_zero<D>(), r3c0 = set_zero<D>();
    D r0c1 = set_zero<D>(), r1c1 = set_zero<D>(), r2c1 = set_zero<D>(), r3c1 = set_zero<D>();
    D r0c2 = set_zero<D>(), r1c2 = set_zero<D>(), r2c2 = set_zero<D>(), r3c2 = set_zero<D>();

    for (int64_t l = 0; l < k; l += vl) {
        V a0 = load<V>(a_r0 + l);
        V a1 = load<V>(a_r1 + l);
        V a2 = load<V>(a_r2 + l);
        V a3 = load<V>(a_r3 + l);

        // One B vector at a time, reused across the four A rows.
        V b = load<V>(b_c0 + l);
        r0c0 = madd(a0, b, r0c0);
        r1c0 = madd(a1, b, r1c0);
        r2c0 = madd(a2, b, r2c0);
        r3c0 = madd(a3, b, r3c0);

        b = load<V>(b_c1 + l);
        r0c1 = madd(a0, b, r0c1);
        r1c1 = madd(a1, b, r1c1);
        r2c1 = madd(a2, b, r2c1);
        r3c1 = madd(a3, b, r3c1);

        b = load<V>(b_c2 + l);
        r0c2 = madd(a0, b, r0c2);
        r1c2 = madd(a1, b, r1c2);
        r2c2 = madd(a2, b, r2c2);
        r3c2 = madd(a3, b, r3c2);
    }

    // C is indexed as C[ldc*column + row]; out_cC points at (row ii, column jj+C).
    TC *out_c0 = C + ldc * (jj + 0) + ii;
    TC *out_c1 = C + ldc * (jj + 1) + ii;
    TC *out_c2 = C + ldc * (jj + 2) + ii;
    out_c0[0] = hsum(r0c0); out_c0[1] = hsum(r1c0); out_c0[2] = hsum(r2c0); out_c0[3] = hsum(r3c0);
    out_c1[0] = hsum(r0c1); out_c1[1] = hsum(r1c1); out_c1[2] = hsum(r2c1); out_c1[3] = hsum(r3c1);
    out_c2[0] = hsum(r0c2); out_c2[1] = hsum(r1c2); out_c2[2] = hsum(r2c2); out_c2[3] = hsum(r3c2);
}
|
|
1014
|
+
|
|
1015
|
+
// Computes one 4x2 tile of the output:
//   C[ldc*(jj+c) + (ii+r)] = sum over l < k of A[lda*(ii+r) + l] * B[ldb*(jj+c) + l],  r in 0..3, c in 0..1
// The k-loop loads a full vlmax<V>() vector per step, so k is assumed to be a multiple of
// vlmax<V>() (NOTE(review): presumably enforced by the caller).
// Accumulators are named rRcC (A row ii+R, B row jj+C).
inline void gemm_bloc_4x2(int64_t ii, int64_t jj) {
    const size_t vl = vlmax<V>();

    // Row base pointers, hoisted out of the k-loop.
    const TA *a_r0 = A + lda * (ii + 0);
    const TA *a_r1 = A + lda * (ii + 1);
    const TA *a_r2 = A + lda * (ii + 2);
    const TA *a_r3 = A + lda * (ii + 3);
    const TB *b_c0 = B + ldb * (jj + 0);
    const TB *b_c1 = B + ldb * (jj + 1);

    D r0c0 = set_zero<D>(), r1c0 = set_zero<D>(), r2c0 = set_zero<D>(), r3c0 = set_zero<D>();
    D r0c1 = set_zero<D>(), r1c1 = set_zero<D>(), r2c1 = set_zero<D>(), r3c1 = set_zero<D>();

    for (int64_t l = 0; l < k; l += vl) {
        V a0 = load<V>(a_r0 + l);
        V a1 = load<V>(a_r1 + l);
        V a2 = load<V>(a_r2 + l);
        V a3 = load<V>(a_r3 + l);

        // One B vector at a time, reused across the four A rows.
        V b = load<V>(b_c0 + l);
        r0c0 = madd(a0, b, r0c0);
        r1c0 = madd(a1, b, r1c0);
        r2c0 = madd(a2, b, r2c0);
        r3c0 = madd(a3, b, r3c0);

        b = load<V>(b_c1 + l);
        r0c1 = madd(a0, b, r0c1);
        r1c1 = madd(a1, b, r1c1);
        r2c1 = madd(a2, b, r2c1);
        r3c1 = madd(a3, b, r3c1);
    }

    // C is indexed as C[ldc*column + row]; out_cC points at (row ii, column jj+C).
    TC *out_c0 = C + ldc * (jj + 0) + ii;
    TC *out_c1 = C + ldc * (jj + 1) + ii;
    out_c0[0] = hsum(r0c0); out_c0[1] = hsum(r1c0); out_c0[2] = hsum(r2c0); out_c0[3] = hsum(r3c0);
    out_c1[0] = hsum(r0c1); out_c1[1] = hsum(r1c1); out_c1[2] = hsum(r2c1); out_c1[3] = hsum(r3c1);
}
|
|
1054
|
+
|
|
1055
|
+
// Computes one 4x1 tile of the output:
//   C[ldc*jj + (ii+r)] = sum over l < k of A[lda*(ii+r) + l] * B[ldb*jj + l],  r in 0..3
// The k-loop loads a full vlmax<V>() vector per step, so k is assumed to be a multiple of
// vlmax<V>() (NOTE(review): presumably enforced by the caller).
inline void gemm_bloc_4x1(int64_t ii, int64_t jj) {
    const size_t vl = vlmax<V>();

    // Row base pointers, hoisted out of the k-loop.
    const TA *a_r0 = A + lda * (ii + 0);
    const TA *a_r1 = A + lda * (ii + 1);
    const TA *a_r2 = A + lda * (ii + 2);
    const TA *a_r3 = A + lda * (ii + 3);
    const TB *b_c0 = B + ldb * (jj + 0);

    D r0c0 = set_zero<D>(), r1c0 = set_zero<D>(), r2c0 = set_zero<D>(), r3c0 = set_zero<D>();

    for (int64_t l = 0; l < k; l += vl) {
        V a0 = load<V>(a_r0 + l);
        V a1 = load<V>(a_r1 + l);
        V a2 = load<V>(a_r2 + l);
        V a3 = load<V>(a_r3 + l);

        V b = load<V>(b_c0 + l);
        r0c0 = madd(a0, b, r0c0);
        r1c0 = madd(a1, b, r1c0);
        r2c0 = madd(a2, b, r2c0);
        r3c0 = madd(a3, b, r3c0);
    }

    // C is indexed as C[ldc*column + row]; out_c0 points at (row ii, column jj).
    TC *out_c0 = C + ldc * (jj + 0) + ii;
    out_c0[0] = hsum(r0c0);
    out_c0[1] = hsum(r1c0);
    out_c0[2] = hsum(r2c0);
    out_c0[3] = hsum(r3c0);
}
|
|
1080
|
+
|
|
1081
|
+
// Computes one 2x2 tile of the output:
//   C[ldc*(jj+c) + (ii+r)] = sum over l < k of A[lda*(ii+r) + l] * B[ldb*(jj+c) + l],  r, c in 0..1
// The k-loop loads a full vlmax<V>() vector per step, so k is assumed to be a multiple of
// vlmax<V>() (NOTE(review): presumably enforced by the caller).
// Accumulators are named rRcC (A row ii+R, B row jj+C).
inline void gemm_bloc_2x2(int64_t ii, int64_t jj) {
    const size_t vl = vlmax<V>();

    // Row base pointers, hoisted out of the k-loop.
    const TA *a_r0 = A + lda * (ii + 0);
    const TA *a_r1 = A + lda * (ii + 1);
    const TB *b_c0 = B + ldb * (jj + 0);
    const TB *b_c1 = B + ldb * (jj + 1);

    D r0c0 = set_zero<D>(), r1c0 = set_zero<D>();
    D r0c1 = set_zero<D>(), r1c1 = set_zero<D>();

    for (int64_t l = 0; l < k; l += vl) {
        V a0 = load<V>(a_r0 + l);
        V a1 = load<V>(a_r1 + l);

        // One B vector at a time, reused across both A rows.
        V b = load<V>(b_c0 + l);
        r0c0 = madd(a0, b, r0c0);
        r1c0 = madd(a1, b, r1c0);

        b = load<V>(b_c1 + l);
        r0c1 = madd(a0, b, r0c1);
        r1c1 = madd(a1, b, r1c1);
    }

    // C is indexed as C[ldc*column + row]; out_cC points at (row ii, column jj+C).
    TC *out_c0 = C + ldc * (jj + 0) + ii;
    TC *out_c1 = C + ldc * (jj + 1) + ii;
    out_c0[0] = hsum(r0c0); out_c0[1] = hsum(r1c0);
    out_c1[0] = hsum(r0c1); out_c1[1] = hsum(r1c1);
}
|
|
1106
|
+
|
|
1107
|
+
// Computes one 2x1 tile of the output:
//   C[ldc*jj + (ii+r)] = sum over l < k of A[lda*(ii+r) + l] * B[ldb*jj + l],  r in 0..1
// The k-loop loads a full vlmax<V>() vector per step, so k is assumed to be a multiple of
// vlmax<V>() (NOTE(review): presumably enforced by the caller).
inline void gemm_bloc_2x1(int64_t ii, int64_t jj) {
    const size_t vl = vlmax<V>();

    // Row base pointers, hoisted out of the k-loop.
    const TA *a_r0 = A + lda * (ii + 0);
    const TA *a_r1 = A + lda * (ii + 1);
    const TB *b_c0 = B + ldb * (jj + 0);

    D r0c0 = set_zero<D>(), r1c0 = set_zero<D>();

    for (int64_t l = 0; l < k; l += vl) {
        V a0 = load<V>(a_r0 + l);
        V a1 = load<V>(a_r1 + l);

        V b = load<V>(b_c0 + l);
        r0c0 = madd(a0, b, r0c0);
        r1c0 = madd(a1, b, r1c0);
    }

    // C is indexed as C[ldc*column + row]; out_c0 points at (row ii, column jj).
    TC *out_c0 = C + ldc * (jj + 0) + ii;
    out_c0[0] = hsum(r0c0);
    out_c0[1] = hsum(r1c0);
}
|
|
1124
|
+
|
|
1125
|
+
template <int RM, int RN>
|
|
1126
|
+
inline void gemm_bloc(int64_t ii, int64_t jj) {
|
|
1127
|
+
if constexpr (RM == 4) {
|
|
1128
|
+
if constexpr (RN == 6) { return gemm_bloc_4x6(ii, jj); }
|
|
1129
|
+
if constexpr (RN == 5) { return gemm_bloc_4x5(ii, jj); }
|
|
1130
|
+
if constexpr (RN == 4) { return gemm_bloc_4x4(ii, jj); }
|
|
1131
|
+
if constexpr (RN == 3) { return gemm_bloc_4x3(ii, jj); }
|
|
1132
|
+
if constexpr (RN == 2) { return gemm_bloc_4x2(ii, jj); }
|
|
1133
|
+
if constexpr (RN == 1) { return gemm_bloc_4x1(ii, jj); }
|
|
1134
|
+
} else if constexpr (RM == 2) {
|
|
1135
|
+
if constexpr (RN == 2) { return gemm_bloc_2x2(ii, jj); }
|
|
1136
|
+
if constexpr (RN == 1) { return gemm_bloc_2x1(ii, jj); }
|
|
1137
|
+
}
|
|
1138
|
+
}
|
|
1139
|
+
|
|
1140
|
+
template <int RM, int RN, int BM>
|
|
1141
|
+
NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
|
|
1142
|
+
GGML_ASSERT(m % (RM * BM) == 0);
|
|
1143
|
+
const int64_t ytiles = m / (RM * BM);
|
|
1144
|
+
const int64_t xtiles = (n + RN -1) / RN;
|
|
1145
|
+
const int64_t jj_RN = (xtiles - (xtiles * RN - n));
|
|
1146
|
+
|
|
1147
|
+
// "round" bloc_size to "nearest" BN
|
|
1148
|
+
const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
|
|
1149
|
+
const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
|
|
1150
|
+
const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
|
|
1151
|
+
const int64_t nb_job = ytiles * NB_BN;
|
|
1152
|
+
|
|
1153
|
+
if (params->ith == 0) {
|
|
1154
|
+
GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
|
|
1155
|
+
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
|
|
1156
|
+
ggml_threadpool_chunk_set(params->threadpool, params->nth);
|
|
1157
|
+
}
|
|
1158
|
+
|
|
1159
|
+
ggml_barrier(params->threadpool);
|
|
1160
|
+
|
|
1161
|
+
int64_t job = params->ith;
|
|
1162
|
+
while (job < nb_job) {
|
|
1163
|
+
const int64_t ii = (job % ytiles) * RM * BM;
|
|
1164
|
+
const int64_t jb = job / ytiles;
|
|
1165
|
+
const int64_t jr0 = BLOC_POS(jb , jj_BN, SIZE_BN);
|
|
1166
|
+
const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
|
|
1167
|
+
|
|
1168
|
+
const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
|
|
1169
|
+
const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
|
|
1170
|
+
const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
|
|
1171
|
+
|
|
1172
|
+
for (int64_t bi = 0; bi < BM * RM; bi += RM) {
|
|
1173
|
+
int64_t jj = jj0;
|
|
1174
|
+
for (; jj < jj1; jj += RN) {
|
|
1175
|
+
gemm_bloc<RM, RN>(ii + bi, jj);
|
|
1176
|
+
}
|
|
1177
|
+
if constexpr (RN > 1) {
|
|
1178
|
+
for (; jj < jj2; jj += RN - 1) {
|
|
1179
|
+
gemm_bloc<RM, RN-1>(ii + bi, jj);
|
|
1180
|
+
}
|
|
1181
|
+
}
|
|
1182
|
+
GGML_ASSERT(jj == jj2);
|
|
1183
|
+
}
|
|
1184
|
+
|
|
1185
|
+
job = ggml_threadpool_chunk_add(params->threadpool, 1);
|
|
1186
|
+
}
|
|
1187
|
+
|
|
1188
|
+
ggml_barrier(params->threadpool);
|
|
1189
|
+
return;
|
|
1190
|
+
}
|
|
1191
|
+
|
|
1192
|
+
const ggml_compute_params * params;
|
|
1193
|
+
const TA *const A;
|
|
1194
|
+
const TB *const B;
|
|
1195
|
+
TC *const C;
|
|
1196
|
+
const int64_t k;
|
|
1197
|
+
const int64_t lda;
|
|
1198
|
+
const int64_t ldb;
|
|
1199
|
+
const int64_t ldc;
|
|
1200
|
+
};
|
|
1201
|
+
#endif
|
|
1202
|
+
|
|
492
1203
|
//////////////////////////////////////////////////////////////////////////////////////////
|
|
493
1204
|
// QUANT ZERO MATRIX MULTIPLICATION
|
|
494
1205
|
|
|
@@ -1573,95 +2284,35 @@ class tinyBLAS_BF16_PPC {
|
|
|
1573
2284
|
const int nth;
|
|
1574
2285
|
};
|
|
1575
2286
|
|
|
1576
|
-
template <typename TA>
|
|
1577
|
-
|
|
1578
|
-
|
|
1579
|
-
|
|
1580
|
-
|
|
1581
|
-
|
|
1582
|
-
float *C, int64_t ldc,
|
|
1583
|
-
int ith, int nth)
|
|
2287
|
+
template <typename TA>
|
|
2288
|
+
tinyBLAS_Q0_PPC<TA>::tinyBLAS_Q0_PPC(int64_t k,
|
|
2289
|
+
const TA *A, int64_t lda,
|
|
2290
|
+
const block_q8_0 *B, int64_t ldb,
|
|
2291
|
+
float *C, int64_t ldc,
|
|
2292
|
+
int ith, int nth)
|
|
1584
2293
|
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
|
2294
|
+
kc = 64;
|
|
1585
2295
|
}
|
|
1586
2296
|
|
|
1587
|
-
|
|
1588
|
-
|
|
1589
|
-
|
|
1590
|
-
|
|
1591
|
-
|
|
1592
|
-
|
|
1593
|
-
|
|
1594
|
-
|
|
1595
|
-
|
|
1596
|
-
|
|
1597
|
-
|
|
1598
|
-
|
|
1599
|
-
|
|
1600
|
-
|
|
1601
|
-
template<int size>
|
|
1602
|
-
inline void compute(acc_t* ACC, int c_idx, int s_idx, std::array<int, size>& comparray, vector float* vs, vector float* fin_res) {
|
|
1603
|
-
vector signed int vec_C[4];
|
|
1604
|
-
vector float CA[4] = {0};
|
|
1605
|
-
vector float res[4] = {0};
|
|
1606
|
-
__builtin_mma_disassemble_acc(vec_C, ACC);
|
|
1607
|
-
for (int i = 0; i < 4; i++) {
|
|
1608
|
-
CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
|
|
1609
|
-
res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
|
|
1610
|
-
fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
|
|
1611
|
-
}
|
|
1612
|
-
}
|
|
1613
|
-
/* This function processes quantized data from block_q4_0 elements.
|
|
1614
|
-
* First the we try to extract the two int4 values stored in single int8_t into two signed int8.
|
|
1615
|
-
* And then we subtract each of the resultant element with 8, to convert signed int8 to unsigned int8.
|
|
1616
|
-
* Also compute the rowsum which is required to compensate the above conversion. */
|
|
1617
|
-
inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
|
|
1618
|
-
const vector signed char lowMask = vec_splats((signed char)0xF);
|
|
1619
|
-
const vector unsigned char v4 = vec_splats((unsigned char)0x4);
|
|
1620
|
-
const vector signed char v8 = vec_splats((signed char)0x8);
|
|
1621
|
-
vector signed int vsum = {0};
|
|
1622
|
-
vector signed int vsum2 = {0};
|
|
1623
|
-
c[0] = vec_and(c[1], lowMask);
|
|
1624
|
-
c[1] = vec_sr(c[1], v4);
|
|
1625
|
-
c[0] = vec_sub(c[0], v8);
|
|
1626
|
-
c[1] = vec_sub(c[1], v8);
|
|
1627
|
-
vsum = vec_sum4s(c[0], vsum);
|
|
1628
|
-
vsum2 = vec_sum4s(c[1], vsum2);
|
|
1629
|
-
vsum = vec_add(vsum, vsum2);
|
|
1630
|
-
*(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
|
1631
|
-
}
|
|
1632
|
-
|
|
1633
|
-
template <typename V1, typename V2>
|
|
1634
|
-
inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
|
|
1635
|
-
vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
|
1636
|
-
vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
|
1637
|
-
vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
|
|
1638
|
-
vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
|
|
1639
|
-
V2 t1, t2, t3, t4, t5, t6, t7, t8;
|
|
1640
|
-
vector unsigned char xor_vector;
|
|
1641
|
-
uint8_t flip_vec = 0x80;
|
|
1642
|
-
xor_vector = vec_splats(flip_vec);
|
|
1643
|
-
t1 = vec_perm(s1, s2, swiz1);
|
|
1644
|
-
t2 = vec_perm(s1, s2, swiz2);
|
|
1645
|
-
t3 = vec_perm(s3, s4, swiz1);
|
|
1646
|
-
t4 = vec_perm(s3, s4, swiz2);
|
|
1647
|
-
t5 = vec_perm(t1, t3, swiz3);
|
|
1648
|
-
t6 = vec_perm(t1, t3, swiz4);
|
|
1649
|
-
t7 = vec_perm(t2, t4, swiz3);
|
|
1650
|
-
t8 = vec_perm(t2, t4, swiz4);
|
|
1651
|
-
if (flip == true) {
|
|
1652
|
-
t5 = vec_xor(t5, xor_vector);
|
|
1653
|
-
t6 = vec_xor(t6, xor_vector);
|
|
1654
|
-
t7 = vec_xor(t7, xor_vector);
|
|
1655
|
-
t8 = vec_xor(t8, xor_vector);
|
|
2297
|
+
template<typename TA>
|
|
2298
|
+
void tinyBLAS_Q0_PPC<TA>::matmul(int64_t m, int64_t n) {
|
|
2299
|
+
int mc = 64; int nc = 64;
|
|
2300
|
+
if (n % 8 == 0 && n < nc) {
|
|
2301
|
+
nc = n;
|
|
2302
|
+
mc = 32 ;
|
|
2303
|
+
kc = 32;
|
|
2304
|
+
}
|
|
2305
|
+
const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0);
|
|
2306
|
+
if (is_aligned) {
|
|
2307
|
+
this->matmul_tiled_q0(m, n, mc, nc, kc);
|
|
2308
|
+
} else {
|
|
2309
|
+
mnpack(0, m, 0, n);
|
|
1656
2310
|
}
|
|
1657
|
-
vec_xst(t5, 0, vecOffset);
|
|
1658
|
-
vec_xst(t6, 0, vecOffset+16);
|
|
1659
|
-
vec_xst(t7, 0, vecOffset+32);
|
|
1660
|
-
vec_xst(t8, 0, vecOffset+48);
|
|
1661
2311
|
}
|
|
1662
2312
|
|
|
1663
|
-
|
|
1664
|
-
|
|
2313
|
+
template<typename TA>
|
|
2314
|
+
template<int size>
|
|
2315
|
+
void tinyBLAS_Q0_PPC<TA>::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
|
|
1665
2316
|
int64_t i, j;
|
|
1666
2317
|
TA *aoffset = NULL;
|
|
1667
2318
|
int8_t *vecOffset = NULL;
|
|
@@ -1781,8 +2432,10 @@ class tinyBLAS_Q0_PPC {
|
|
|
1781
2432
|
}
|
|
1782
2433
|
}
|
|
1783
2434
|
}
|
|
2435
|
+
|
|
2436
|
+
template<typename TA>
|
|
1784
2437
|
template<typename VA, typename VB>
|
|
1785
|
-
void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
|
|
2438
|
+
void tinyBLAS_Q0_PPC<TA>::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
|
|
1786
2439
|
int64_t i, j;
|
|
1787
2440
|
block_q8_0 *aoffset = NULL;
|
|
1788
2441
|
VA *vecOffset = NULL;
|
|
@@ -1822,7 +2475,6 @@ class tinyBLAS_Q0_PPC {
|
|
|
1822
2475
|
j--;
|
|
1823
2476
|
} while(j > 0);
|
|
1824
2477
|
}
|
|
1825
|
-
|
|
1826
2478
|
if (rows & 4) {
|
|
1827
2479
|
aoffsets[0] = aoffset;
|
|
1828
2480
|
for (int it = 1; it < 4; it++ )
|
|
@@ -1878,7 +2530,8 @@ class tinyBLAS_Q0_PPC {
|
|
|
1878
2530
|
}
|
|
1879
2531
|
}
|
|
1880
2532
|
|
|
1881
|
-
|
|
2533
|
+
template<typename TA>
|
|
2534
|
+
void tinyBLAS_Q0_PPC<TA>::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
|
1882
2535
|
int m_rem = MIN(m - m0, 16);
|
|
1883
2536
|
int n_rem = MIN(n - n0, 16);
|
|
1884
2537
|
|
|
@@ -1915,7 +2568,8 @@ class tinyBLAS_Q0_PPC {
|
|
|
1915
2568
|
}
|
|
1916
2569
|
|
|
1917
2570
|
|
|
1918
|
-
|
|
2571
|
+
template<typename TA>
|
|
2572
|
+
void tinyBLAS_Q0_PPC<TA>::KERNEL_4x8(int64_t ii, int64_t jj) {
|
|
1919
2573
|
vec_t vec_A[8], vec_B[16] = {0};
|
|
1920
2574
|
acc_t acc_0, acc_1;
|
|
1921
2575
|
std::array<int, 4> comparray {};
|
|
@@ -1953,14 +2607,15 @@ class tinyBLAS_Q0_PPC {
|
|
|
1953
2607
|
aoffset += lda;
|
|
1954
2608
|
}
|
|
1955
2609
|
}
|
|
1956
|
-
compute
|
|
1957
|
-
compute
|
|
2610
|
+
compute(&acc_0, 0, 0, comparray, vs, fin_res);
|
|
2611
|
+
compute(&acc_1, 0, 4, comparray, vs, fin_res);
|
|
1958
2612
|
}
|
|
1959
2613
|
save_res(ii, jj, 0, fin_res);
|
|
1960
2614
|
save_res(ii, jj+4, 4, fin_res);
|
|
1961
2615
|
}
|
|
1962
2616
|
|
|
1963
|
-
|
|
2617
|
+
template<typename TA>
|
|
2618
|
+
void tinyBLAS_Q0_PPC<TA>::KERNEL_8x4(int64_t ii, int64_t jj) {
|
|
1964
2619
|
vec_t vec_A[16], vec_B[8] = {0};
|
|
1965
2620
|
acc_t acc_0, acc_1;
|
|
1966
2621
|
std::array<int, 8> comparray {};
|
|
@@ -1997,16 +2652,18 @@ class tinyBLAS_Q0_PPC {
|
|
|
1997
2652
|
aoffset += lda;
|
|
1998
2653
|
}
|
|
1999
2654
|
}
|
|
2000
|
-
compute
|
|
2001
|
-
compute
|
|
2655
|
+
compute(&acc_0, 0, 0, comparray, vs, fin_res);
|
|
2656
|
+
compute(&acc_1, 4, 4, comparray, vs, fin_res);
|
|
2002
2657
|
}
|
|
2003
2658
|
save_res(ii, jj, 0, fin_res);
|
|
2004
2659
|
save_res(ii+4, jj, 4, fin_res);
|
|
2005
2660
|
}
|
|
2006
2661
|
|
|
2007
|
-
|
|
2662
|
+
template<typename TA>
|
|
2663
|
+
void tinyBLAS_Q0_PPC<TA>::KERNEL_8x8(int64_t ii, int64_t jj) {
|
|
2008
2664
|
vec_t vec_A[16], vec_B[16] = {0};
|
|
2009
2665
|
acc_t acc_0, acc_1, acc_2, acc_3;
|
|
2666
|
+
acc_t acc_4, acc_5, acc_6, acc_7;
|
|
2010
2667
|
std::array<int, 8> comparray {};
|
|
2011
2668
|
vector float fin_res[16] = {0};
|
|
2012
2669
|
vector float vs[16] = {0};
|
|
@@ -2046,10 +2703,10 @@ class tinyBLAS_Q0_PPC {
|
|
|
2046
2703
|
aoffset += lda;
|
|
2047
2704
|
}
|
|
2048
2705
|
}
|
|
2049
|
-
compute
|
|
2050
|
-
compute
|
|
2051
|
-
compute
|
|
2052
|
-
compute
|
|
2706
|
+
compute(&acc_0, 0, 0, comparray, vs, fin_res);
|
|
2707
|
+
compute(&acc_1, 4, 4, comparray, vs, fin_res);
|
|
2708
|
+
compute(&acc_2, 0, 8, comparray, vs, fin_res);
|
|
2709
|
+
compute(&acc_3, 4, 12, comparray, vs, fin_res);
|
|
2053
2710
|
}
|
|
2054
2711
|
save_res(ii, jj, 0, fin_res);
|
|
2055
2712
|
save_res(ii+4, jj, 4, fin_res);
|
|
@@ -2057,7 +2714,8 @@ class tinyBLAS_Q0_PPC {
|
|
|
2057
2714
|
save_res(ii+4, jj+4, 12, fin_res);
|
|
2058
2715
|
}
|
|
2059
2716
|
|
|
2060
|
-
|
|
2717
|
+
template<typename TA>
|
|
2718
|
+
void tinyBLAS_Q0_PPC<TA>::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
|
|
2061
2719
|
int64_t ytiles = (m - m0) / RM;
|
|
2062
2720
|
int64_t xtiles = (n - n0) / RN;
|
|
2063
2721
|
int64_t tiles = xtiles * ytiles;
|
|
@@ -2125,21 +2783,9 @@ class tinyBLAS_Q0_PPC {
|
|
|
2125
2783
|
}
|
|
2126
2784
|
}
|
|
2127
2785
|
|
|
2128
|
-
template<
|
|
2129
|
-
inline void kernel(int64_t ii, int64_t jj) {
|
|
2130
|
-
if constexpr(RM == 4 && RN == 8) {
|
|
2131
|
-
KERNEL_4x8(ii,jj);
|
|
2132
|
-
} else if constexpr(RM == 8 && RN == 4) {
|
|
2133
|
-
KERNEL_8x4(ii,jj);
|
|
2134
|
-
} else if constexpr(RM == 8 && RN == 8) {
|
|
2135
|
-
KERNEL_8x8(ii,jj);
|
|
2136
|
-
} else {
|
|
2137
|
-
assert(false && "RN/RM values not supported");
|
|
2138
|
-
}
|
|
2139
|
-
}
|
|
2140
|
-
|
|
2786
|
+
template<typename TA>
|
|
2141
2787
|
template <int RM, int RN>
|
|
2142
|
-
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
|
2788
|
+
NOINLINE void tinyBLAS_Q0_PPC<TA>::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
|
2143
2789
|
int64_t ytiles = (m - m0) / RM;
|
|
2144
2790
|
int64_t xtiles = (n - n0) / RN;
|
|
2145
2791
|
int64_t tiles = xtiles * ytiles;
|
|
@@ -2151,20 +2797,12 @@ class tinyBLAS_Q0_PPC {
|
|
|
2151
2797
|
for (int64_t job = start; job < end; ++job) {
|
|
2152
2798
|
int64_t ii = m0 + job / xtiles * RM;
|
|
2153
2799
|
int64_t jj = n0 + job % xtiles * RN;
|
|
2154
|
-
kernel<RM, RN>(ii, jj);
|
|
2800
|
+
this->kernel<RM, RN>(ii, jj);
|
|
2155
2801
|
}
|
|
2156
2802
|
}
|
|
2157
2803
|
|
|
2158
|
-
|
|
2159
|
-
|
|
2160
|
-
float *C;
|
|
2161
|
-
const int64_t k;
|
|
2162
|
-
const int64_t lda;
|
|
2163
|
-
const int64_t ldb;
|
|
2164
|
-
const int64_t ldc;
|
|
2165
|
-
const int ith;
|
|
2166
|
-
const int nth;
|
|
2167
|
-
};
|
|
2804
|
+
template class tinyBLAS_Q0_PPC<block_q4_0>;
|
|
2805
|
+
template class tinyBLAS_Q0_PPC<block_q8_0>;
|
|
2168
2806
|
|
|
2169
2807
|
class tinyBLAS_PPC {
|
|
2170
2808
|
public:
|
|
@@ -2731,6 +3369,24 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
|
|
2731
3369
|
params->ith, params->nth};
|
|
2732
3370
|
tb.matmul(m, n);
|
|
2733
3371
|
return true;
|
|
3372
|
+
#elif defined(__riscv_zvfh)
|
|
3373
|
+
#if LMUL == 1
|
|
3374
|
+
tinyBLAS_RVV<vfloat32m1_t, vfloat32m1_t, float, float, float> tb{ params,
|
|
3375
|
+
k, (const float *)A, lda,
|
|
3376
|
+
(const float *)B, ldb,
|
|
3377
|
+
(float *)C, ldc};
|
|
3378
|
+
#elif LMUL == 2
|
|
3379
|
+
tinyBLAS_RVV<vfloat32m2_t, vfloat32m2_t, float, float, float> tb{ params,
|
|
3380
|
+
k, (const float *)A, lda,
|
|
3381
|
+
(const float *)B, ldb,
|
|
3382
|
+
(float *)C, ldc};
|
|
3383
|
+
#else // LMUL = 4
|
|
3384
|
+
tinyBLAS_RVV<vfloat32m4_t, vfloat32m4_t, float, float, float> tb{ params,
|
|
3385
|
+
k, (const float *)A, lda,
|
|
3386
|
+
(const float *)B, ldb,
|
|
3387
|
+
(float *)C, ldc};
|
|
3388
|
+
#endif
|
|
3389
|
+
return tb.matmul(m, n);
|
|
2734
3390
|
#else
|
|
2735
3391
|
return false;
|
|
2736
3392
|
#endif
|
|
@@ -2773,6 +3429,24 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
|
|
2773
3429
|
tb.matmul(m, n);
|
|
2774
3430
|
return true;
|
|
2775
3431
|
}
|
|
3432
|
+
#elif defined(__riscv_zvfbfwma)
|
|
3433
|
+
#if LMUL == 1
|
|
3434
|
+
tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
|
|
3435
|
+
k, (const ggml_bf16_t *)A, lda,
|
|
3436
|
+
(const ggml_bf16_t *)B, ldb,
|
|
3437
|
+
(float *)C, ldc};
|
|
3438
|
+
#elif LMUL == 2
|
|
3439
|
+
tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
|
|
3440
|
+
k, (const ggml_bf16_t *)A, lda,
|
|
3441
|
+
(const ggml_bf16_t *)B, ldb,
|
|
3442
|
+
(float *)C, ldc};
|
|
3443
|
+
#else // LMUL = 4
|
|
3444
|
+
tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
|
|
3445
|
+
k, (const ggml_bf16_t *)A, lda,
|
|
3446
|
+
(const ggml_bf16_t *)B, ldb,
|
|
3447
|
+
(float *)C, ldc};
|
|
3448
|
+
#endif
|
|
3449
|
+
return tb.matmul(m, n);
|
|
2776
3450
|
#endif
|
|
2777
3451
|
return false;
|
|
2778
3452
|
}
|
|
@@ -2822,6 +3496,26 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
|
|
2822
3496
|
(float *)C, ldc};
|
|
2823
3497
|
return tb.matmul(m, n);
|
|
2824
3498
|
}
|
|
3499
|
+
#elif defined(__riscv_zvfh)
|
|
3500
|
+
if (Btype == GGML_TYPE_F16) {
|
|
3501
|
+
#if LMUL == 1
|
|
3502
|
+
tinyBLAS_RVV<vfloat32m1_t, vfloat16mf2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
|
|
3503
|
+
k, (const ggml_fp16_t *)A, lda,
|
|
3504
|
+
(const ggml_fp16_t *)B, ldb,
|
|
3505
|
+
(float *)C, ldc};
|
|
3506
|
+
#elif LMUL == 2
|
|
3507
|
+
tinyBLAS_RVV<vfloat32m2_t, vfloat16m1_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
|
|
3508
|
+
k, (const ggml_fp16_t *)A, lda,
|
|
3509
|
+
(const ggml_fp16_t *)B, ldb,
|
|
3510
|
+
(float *)C, ldc};
|
|
3511
|
+
#else // LMUL = 4
|
|
3512
|
+
tinyBLAS_RVV<vfloat32m4_t, vfloat16m2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
|
|
3513
|
+
k, (const ggml_fp16_t *)A, lda,
|
|
3514
|
+
(const ggml_fp16_t *)B, ldb,
|
|
3515
|
+
(float *)C, ldc};
|
|
3516
|
+
#endif
|
|
3517
|
+
return tb.matmul(m, n);
|
|
3518
|
+
}
|
|
2825
3519
|
#endif
|
|
2826
3520
|
return false;
|
|
2827
3521
|
}
|