whispercpp 1.3.4 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +158 -44
- data/ext/extconf.rb +3 -2
- data/ext/ruby_whisper.c +34 -6
- data/ext/ruby_whisper.h +67 -0
- data/ext/ruby_whisper_context.c +236 -144
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +12 -13
- data/ext/ruby_whisper_params.c +47 -24
- data/ext/ruby_whisper_segment.c +84 -20
- data/ext/ruby_whisper_token.c +371 -0
- data/ext/ruby_whisper_transcribe.cpp +5 -2
- data/ext/ruby_whisper_vad_context.c +122 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +138 -0
- data/ext/ruby_whisper_vad_segments.c +105 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +129 -112
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +28 -15
- data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
- data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
- data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
- data/ext/sources/examples/talk-llama/llama-context.h +70 -23
- data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
- data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
- data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
- data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
- data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
- data/ext/sources/examples/talk-llama/llama-model.h +112 -18
- data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
- data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
- data/ext/sources/examples/talk-llama/llama.cpp +802 -21
- data/ext/sources/examples/talk-llama/llama.h +210 -39
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +704 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +90 -56
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +5 -2
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +6 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +14 -12
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +246 -21
- data/ext/sources/ggml/src/CMakeLists.txt +85 -11
- data/ext/sources/ggml/src/ggml-alloc.c +128 -50
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
- data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
- data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
- data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
- data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
- data/ext/sources/ggml/src/ggml-impl.h +129 -6
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
- data/ext/sources/ggml/src/ggml.c +590 -64
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +106 -62
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +162 -4
- data/test/test_context_params.rb +82 -0
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +81 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +100 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +27 -0
- data/whispercpp.gemspec +1 -1
- metadata +502 -37
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -69,6 +69,10 @@
|
|
|
69
69
|
#define VECTOR_REGISTERS 16
|
|
70
70
|
#endif
|
|
71
71
|
|
|
72
|
+
#if defined(__riscv_v_intrinsic)
|
|
73
|
+
#define LMUL 4
|
|
74
|
+
#endif
|
|
75
|
+
|
|
72
76
|
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
|
|
73
77
|
|
|
74
78
|
namespace {
|
|
@@ -176,6 +180,46 @@ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
|
|
|
176
180
|
}
|
|
177
181
|
#endif
|
|
178
182
|
|
|
183
|
+
#if defined(__riscv_zvfh)
|
|
184
|
+
template <>
|
|
185
|
+
inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
|
|
186
|
+
return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
|
|
187
|
+
}
|
|
188
|
+
inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
|
|
189
|
+
return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
|
|
190
|
+
}
|
|
191
|
+
inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
|
|
192
|
+
return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
|
|
193
|
+
}
|
|
194
|
+
inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
|
|
195
|
+
return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
|
|
196
|
+
}
|
|
197
|
+
inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
|
|
198
|
+
return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
|
|
199
|
+
}
|
|
200
|
+
inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
|
|
201
|
+
return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
|
|
202
|
+
}
|
|
203
|
+
inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
|
|
204
|
+
return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
|
|
205
|
+
}
|
|
206
|
+
inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
|
|
207
|
+
return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
|
|
208
|
+
}
|
|
209
|
+
#endif
|
|
210
|
+
|
|
211
|
+
#if defined(__riscv_zvfbfwma)
// bf16 inputs, fp32 accumulator: forwards to the widening bf16
// multiply-accumulate intrinsic with `c` as the accumulator operand, over the
// maximum vector length of the accumulator's register group.
inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
    const size_t vl = __riscv_vsetvlmax_e32m1();
    return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, vl);
}

inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
    const size_t vl = __riscv_vsetvlmax_e32m2();
    return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, vl);
}

inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
    const size_t vl = __riscv_vsetvlmax_e32m4();
    return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, vl);
}
#endif
|
|
222
|
+
|
|
179
223
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
180
224
|
// VECTORIZED HORIZONTAL SUM
|
|
181
225
|
|
|
@@ -228,6 +272,25 @@ inline float hsum(__m512 x) {
|
|
|
228
272
|
}
|
|
229
273
|
#endif // __AVX512F__
|
|
230
274
|
|
|
275
|
+
#if defined(__riscv_zvfh)
// Horizontal sum of an RVV fp32 register group.
//
// Each overload runs the unordered-sum reduction intrinsic (vfredusum) over
// the group's maximum vector length, seeded with a one-element m1 vector
// holding 0, and returns the first element of the m1 result as a scalar.
inline float hsum(vfloat32m1_t x) {
    vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(0, 1);
    vfloat32m1_t total = __riscv_vfredusum_vs_f32m1_f32m1(x, seed, __riscv_vsetvlmax_e32m1());
    return __riscv_vfmv_f_s_f32m1_f32(total);
}

inline float hsum(vfloat32m2_t x) {
    vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(0, 1);
    vfloat32m1_t total = __riscv_vfredusum_vs_f32m2_f32m1(x, seed, __riscv_vsetvlmax_e32m2());
    return __riscv_vfmv_f_s_f32m1_f32(total);
}

inline float hsum(vfloat32m4_t x) {
    vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(0, 1);
    vfloat32m1_t total = __riscv_vfredusum_vs_f32m4_f32m1(x, seed, __riscv_vsetvlmax_e32m4());
    return __riscv_vfmv_f_s_f32m1_f32(total);
}

inline float hsum(vfloat32m8_t x) {
    vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(0, 1);
    vfloat32m1_t total = __riscv_vfredusum_vs_f32m8_f32m1(x, seed, __riscv_vsetvlmax_e32m8());
    return __riscv_vfmv_f_s_f32m1_f32(total);
}
#endif
|
|
293
|
+
|
|
231
294
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
232
295
|
// VECTORIZED MEMORY LOADING
|
|
233
296
|
|
|
@@ -316,6 +379,88 @@ template <> inline __m256bh load(const float *p) {
|
|
|
316
379
|
}
|
|
317
380
|
#endif
|
|
318
381
|
|
|
382
|
+
#if defined(__riscv_zvfh)
// RVV specializations of load<T>(): read one full register group
// (vsetvlmax elements for the requested type) starting at p.

// fp16: the ggml_fp16_t pointer is reinterpreted as _Float16 for the intrinsic.
template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) {
    const _Float16 * src = reinterpret_cast<const _Float16 *>(p);
    return __riscv_vle16_v_f16mf2(src, __riscv_vsetvlmax_e16mf2());
}

template <> inline vfloat16m1_t load(const ggml_fp16_t *p) {
    const _Float16 * src = reinterpret_cast<const _Float16 *>(p);
    return __riscv_vle16_v_f16m1(src, __riscv_vsetvlmax_e16m1());
}

template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
    const _Float16 * src = reinterpret_cast<const _Float16 *>(p);
    return __riscv_vle16_v_f16m2(src, __riscv_vsetvlmax_e16m2());
}

template <> inline vfloat16m4_t load(const ggml_fp16_t *p) {
    const _Float16 * src = reinterpret_cast<const _Float16 *>(p);
    return __riscv_vle16_v_f16m4(src, __riscv_vsetvlmax_e16m4());
}

// fp32
template <> inline vfloat32m1_t load(const float *p) {
    const size_t vl = __riscv_vsetvlmax_e32m1();
    return __riscv_vle32_v_f32m1(p, vl);
}

template <> inline vfloat32m2_t load(const float *p) {
    const size_t vl = __riscv_vsetvlmax_e32m2();
    return __riscv_vle32_v_f32m2(p, vl);
}

template <> inline vfloat32m4_t load(const float *p) {
    const size_t vl = __riscv_vsetvlmax_e32m4();
    return __riscv_vle32_v_f32m4(p, vl);
}

template <> inline vfloat32m8_t load(const float *p) {
    const size_t vl = __riscv_vsetvlmax_e32m8();
    return __riscv_vle32_v_f32m8(p, vl);
}
#endif
|
|
408
|
+
|
|
409
|
+
#if defined(__riscv_zvfbfwma)
// bf16 RVV specializations of load<T>(): the ggml_bf16_t pointer is
// reinterpreted as __bf16 and one full register group (vsetvlmax 16-bit
// elements for the requested LMUL) is read starting at p.
template <> inline vbfloat16mf2_t load(const ggml_bf16_t *p) {
    const __bf16 * src = reinterpret_cast<const __bf16*>(p);
    return __riscv_vle16_v_bf16mf2(src, __riscv_vsetvlmax_e16mf2());
}

template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) {
    const __bf16 * src = reinterpret_cast<const __bf16*>(p);
    return __riscv_vle16_v_bf16m1(src, __riscv_vsetvlmax_e16m1());
}

template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) {
    const __bf16 * src = reinterpret_cast<const __bf16*>(p);
    return __riscv_vle16_v_bf16m2(src, __riscv_vsetvlmax_e16m2());
}
#endif
|
|
420
|
+
|
|
421
|
+
#if defined(__riscv_zvfh)
// set_zero<T>(): returns an RVV register group of type T with every element
// (up to vsetvlmax for that type) set to zero. Only the specializations below
// are defined; the primary template is declared but has no body.
template <typename T> T set_zero();

// fp16 register groups
template <> inline vfloat16mf2_t set_zero() {
    const size_t vl = __riscv_vsetvlmax_e16mf2();
    return __riscv_vfmv_v_f_f16mf2(0, vl);
}

template <> inline vfloat16m1_t set_zero() {
    const size_t vl = __riscv_vsetvlmax_e16m1();
    return __riscv_vfmv_v_f_f16m1(0, vl);
}

template <> inline vfloat16m2_t set_zero() {
    const size_t vl = __riscv_vsetvlmax_e16m2();
    return __riscv_vfmv_v_f_f16m2(0, vl);
}

template <> inline vfloat16m4_t set_zero() {
    const size_t vl = __riscv_vsetvlmax_e16m4();
    return __riscv_vfmv_v_f_f16m4(0, vl);
}

// fp32 register groups
template <> inline vfloat32m1_t set_zero() {
    const size_t vl = __riscv_vsetvlmax_e32m1();
    return __riscv_vfmv_v_f_f32m1(0.0f, vl);
}

template <> inline vfloat32m2_t set_zero() {
    const size_t vl = __riscv_vsetvlmax_e32m2();
    return __riscv_vfmv_v_f_f32m2(0, vl);
}

template <> inline vfloat32m4_t set_zero() {
    const size_t vl = __riscv_vsetvlmax_e32m4();
    return __riscv_vfmv_v_f_f32m4(0, vl);
}

template <> inline vfloat32m8_t set_zero() {
    const size_t vl = __riscv_vsetvlmax_e32m8();
    return __riscv_vfmv_v_f_f32m8(0, vl);
}
#endif
|
|
449
|
+
|
|
450
|
+
#if defined(__riscv_v_intrinsic)
|
|
451
|
+
template <typename T> size_t vlmax() {
|
|
452
|
+
if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
|
|
453
|
+
else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
|
|
454
|
+
else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
|
|
455
|
+
else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
|
|
456
|
+
else if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); }
|
|
457
|
+
else if constexpr (std::is_same_v<T, vfloat32m2_t>) { return __riscv_vsetvlmax_e32m2(); }
|
|
458
|
+
else if constexpr (std::is_same_v<T, vfloat32m4_t>) { return __riscv_vsetvlmax_e32m4(); }
|
|
459
|
+
else if constexpr (std::is_same_v<T, vfloat32m8_t>) { return __riscv_vsetvlmax_e32m8(); }
|
|
460
|
+
return 0;
|
|
461
|
+
}
|
|
462
|
+
#endif
|
|
463
|
+
|
|
319
464
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
320
465
|
// FLOATING POINT MATRIX MULTIPLICATION
|
|
321
466
|
|
|
@@ -388,7 +533,7 @@ class tinyBLAS {
|
|
|
388
533
|
if constexpr (RN > 1) {
|
|
389
534
|
return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
|
|
390
535
|
} else {
|
|
391
|
-
GGML_LOG_ERROR("mnpack<%d, %d>
|
|
536
|
+
GGML_LOG_ERROR("mnpack<%d, %d> block size not supported\n", RM, (int)SIZE_N);
|
|
392
537
|
GGML_ASSERT(false); // we have miss something.
|
|
393
538
|
}
|
|
394
539
|
}
|
|
@@ -489,6 +634,573 @@ class tinyBLAS {
|
|
|
489
634
|
const int64_t ldc;
|
|
490
635
|
};
|
|
491
636
|
|
|
637
|
+
#if defined(__riscv_v_intrinsic)
|
|
638
|
+
template <typename D, typename V, typename TA, typename TB, typename TC>
|
|
639
|
+
class tinyBLAS_RVV {
|
|
640
|
+
public:
|
|
641
|
+
tinyBLAS_RVV(const ggml_compute_params * params, int64_t k,
|
|
642
|
+
const TA *A, int64_t lda,
|
|
643
|
+
const TB *B, int64_t ldb,
|
|
644
|
+
TC *C, int64_t ldc)
|
|
645
|
+
: params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
|
|
646
|
+
}
|
|
647
|
+
|
|
648
|
+
bool matmul(int64_t m, int64_t n) {
|
|
649
|
+
if (k % vlmax<V>() != 0) {
|
|
650
|
+
return false;
|
|
651
|
+
}
|
|
652
|
+
|
|
653
|
+
#if LMUL == 1
|
|
654
|
+
if (m % 16 == 0 && (m/16 >= params->nth)) {
|
|
655
|
+
const int64_t SIZE_N = BLOCK_SIZE<6>(n);
|
|
656
|
+
mnpack<4, 6, 4>(m, n, SIZE_N, 12);
|
|
657
|
+
return true;
|
|
658
|
+
}
|
|
659
|
+
if (m % 8 == 0 ) {
|
|
660
|
+
const int64_t SIZE_N = BLOCK_SIZE<6>(n);
|
|
661
|
+
mnpack<4, 6, 2>(m, n, SIZE_N, 12);
|
|
662
|
+
return true;
|
|
663
|
+
}
|
|
664
|
+
if (m % 4 == 0) {
|
|
665
|
+
const int64_t SIZE_N = BLOCK_SIZE<6>(n);
|
|
666
|
+
mnpack<4, 6, 1>(m, n, SIZE_N, 12);
|
|
667
|
+
return true;
|
|
668
|
+
}
|
|
669
|
+
#elif LMUL == 2
|
|
670
|
+
if (m % 16 == 0 && (m/16 >= params->nth)) {
|
|
671
|
+
const int64_t SIZE_N = BLOCK_SIZE<3>(n);
|
|
672
|
+
mnpack<4, 3, 4>(m, n, SIZE_N, 24);
|
|
673
|
+
return true;
|
|
674
|
+
}
|
|
675
|
+
if (m % 8 == 0 ) {
|
|
676
|
+
const int64_t SIZE_N = BLOCK_SIZE<3>(n);
|
|
677
|
+
mnpack<4, 3, 2>(m, n, SIZE_N, 24);
|
|
678
|
+
return true;
|
|
679
|
+
}
|
|
680
|
+
if (m % 4 == 0) {
|
|
681
|
+
const int64_t SIZE_N = BLOCK_SIZE<3>(n);
|
|
682
|
+
mnpack<4, 3, 1>(m, n, SIZE_N, 24);
|
|
683
|
+
return true;
|
|
684
|
+
}
|
|
685
|
+
#else // LMUL = 4
|
|
686
|
+
if (m % 16 == 0 && (m/16 >= params->nth)) {
|
|
687
|
+
const int64_t SIZE_N = BLOCK_SIZE<2>(n);
|
|
688
|
+
mnpack<2, 2, 8>(m, n, SIZE_N, 36);
|
|
689
|
+
return true;
|
|
690
|
+
}
|
|
691
|
+
if (m % 8 == 0 ) {
|
|
692
|
+
const int64_t SIZE_N = BLOCK_SIZE<2>(n);
|
|
693
|
+
mnpack<2, 2, 4>(m, n, SIZE_N, 36);
|
|
694
|
+
return true;
|
|
695
|
+
}
|
|
696
|
+
if (m % 4 == 0) {
|
|
697
|
+
const int64_t SIZE_N = BLOCK_SIZE<2>(n);
|
|
698
|
+
mnpack<2, 2, 2>(m, n, SIZE_N, 36);
|
|
699
|
+
return true;
|
|
700
|
+
}
|
|
701
|
+
#endif
|
|
702
|
+
return false;
|
|
703
|
+
}
|
|
704
|
+
|
|
705
|
+
private:
|
|
706
|
+
template<int RM, int RN, int BM>
|
|
707
|
+
inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
|
|
708
|
+
if (SIZE_N == RN) {
|
|
709
|
+
return gemm<RM, RN, BM>(m, n, BN);
|
|
710
|
+
}
|
|
711
|
+
if constexpr (RN > 1) {
|
|
712
|
+
return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
|
|
713
|
+
} else {
|
|
714
|
+
GGML_LOG_ERROR("mnpack<%d, %d> block size not supported\n", RM, (int)SIZE_N);
|
|
715
|
+
GGML_ASSERT(false); // we have miss something.
|
|
716
|
+
}
|
|
717
|
+
}
|
|
718
|
+
|
|
719
|
+
inline void gemm_bloc_4x6(int64_t ii, int64_t jj) {
|
|
720
|
+
size_t vl = vlmax<V>();
|
|
721
|
+
D Cv00 = set_zero<D>();
|
|
722
|
+
D Cv01 = set_zero<D>();
|
|
723
|
+
D Cv02 = set_zero<D>();
|
|
724
|
+
D Cv03 = set_zero<D>();
|
|
725
|
+
D Cv10 = set_zero<D>();
|
|
726
|
+
D Cv11 = set_zero<D>();
|
|
727
|
+
D Cv12 = set_zero<D>();
|
|
728
|
+
D Cv13 = set_zero<D>();
|
|
729
|
+
D Cv20 = set_zero<D>();
|
|
730
|
+
D Cv21 = set_zero<D>();
|
|
731
|
+
D Cv22 = set_zero<D>();
|
|
732
|
+
D Cv23 = set_zero<D>();
|
|
733
|
+
D Cv30 = set_zero<D>();
|
|
734
|
+
D Cv31 = set_zero<D>();
|
|
735
|
+
D Cv32 = set_zero<D>();
|
|
736
|
+
D Cv33 = set_zero<D>();
|
|
737
|
+
D Cv40 = set_zero<D>();
|
|
738
|
+
D Cv41 = set_zero<D>();
|
|
739
|
+
D Cv42 = set_zero<D>();
|
|
740
|
+
D Cv43 = set_zero<D>();
|
|
741
|
+
D Cv50 = set_zero<D>();
|
|
742
|
+
D Cv51 = set_zero<D>();
|
|
743
|
+
D Cv52 = set_zero<D>();
|
|
744
|
+
D Cv53 = set_zero<D>();
|
|
745
|
+
|
|
746
|
+
for (int64_t l = 0; l < k; l += vl) {
|
|
747
|
+
V Bv0 = load<V>(B + ldb * (jj + 0) + l);
|
|
748
|
+
V Bv1 = load<V>(B + ldb * (jj + 1) + l);
|
|
749
|
+
V Bv2 = load<V>(B + ldb * (jj + 2) + l);
|
|
750
|
+
V Bv3 = load<V>(B + ldb * (jj + 3) + l);
|
|
751
|
+
V Bv4 = load<V>(B + ldb * (jj + 4) + l);
|
|
752
|
+
V Bv5 = load<V>(B + ldb * (jj + 5) + l);
|
|
753
|
+
|
|
754
|
+
V Av0 = load<V>(A + lda * (ii + 0) + l);
|
|
755
|
+
Cv00 = madd(Av0, Bv0, Cv00);
|
|
756
|
+
Cv10 = madd(Av0, Bv1, Cv10);
|
|
757
|
+
Cv20 = madd(Av0, Bv2, Cv20);
|
|
758
|
+
Cv30 = madd(Av0, Bv3, Cv30);
|
|
759
|
+
Cv40 = madd(Av0, Bv4, Cv40);
|
|
760
|
+
Cv50 = madd(Av0, Bv5, Cv50);
|
|
761
|
+
|
|
762
|
+
V Av1 = load<V>(A + lda * (ii + 1) + l);
|
|
763
|
+
Cv01 = madd(Av1, Bv0, Cv01);
|
|
764
|
+
Cv11 = madd(Av1, Bv1, Cv11);
|
|
765
|
+
Cv21 = madd(Av1, Bv2, Cv21);
|
|
766
|
+
Cv31 = madd(Av1, Bv3, Cv31);
|
|
767
|
+
Cv41 = madd(Av1, Bv4, Cv41);
|
|
768
|
+
Cv51 = madd(Av1, Bv5, Cv51);
|
|
769
|
+
|
|
770
|
+
V Av2 = load<V>(A + lda * (ii + 2) + l);
|
|
771
|
+
Cv02 = madd(Av2, Bv0, Cv02);
|
|
772
|
+
Cv12 = madd(Av2, Bv1, Cv12);
|
|
773
|
+
Cv22 = madd(Av2, Bv2, Cv22);
|
|
774
|
+
Cv32 = madd(Av2, Bv3, Cv32);
|
|
775
|
+
Cv42 = madd(Av2, Bv4, Cv42);
|
|
776
|
+
Cv52 = madd(Av2, Bv5, Cv52);
|
|
777
|
+
|
|
778
|
+
V Av3 = load<V>(A + lda * (ii + 3) + l);
|
|
779
|
+
Cv03 = madd(Av3, Bv0, Cv03);
|
|
780
|
+
Cv13 = madd(Av3, Bv1, Cv13);
|
|
781
|
+
Cv23 = madd(Av3, Bv2, Cv23);
|
|
782
|
+
Cv33 = madd(Av3, Bv3, Cv33);
|
|
783
|
+
Cv43 = madd(Av3, Bv4, Cv43);
|
|
784
|
+
Cv53 = madd(Av3, Bv5, Cv53);
|
|
785
|
+
}
|
|
786
|
+
|
|
787
|
+
C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
|
|
788
|
+
C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
|
|
789
|
+
C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
|
|
790
|
+
C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
|
|
791
|
+
C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
|
|
792
|
+
C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
|
|
793
|
+
C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
|
|
794
|
+
C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
|
|
795
|
+
C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
|
|
796
|
+
C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
|
|
797
|
+
C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
|
|
798
|
+
C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
|
|
799
|
+
C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
|
|
800
|
+
C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
|
|
801
|
+
C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
|
|
802
|
+
C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
|
|
803
|
+
C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
|
|
804
|
+
C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
|
|
805
|
+
C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
|
|
806
|
+
C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
|
|
807
|
+
C[ldc * (jj + 5) + (ii + 0)] = hsum(Cv50);
|
|
808
|
+
C[ldc * (jj + 5) + (ii + 1)] = hsum(Cv51);
|
|
809
|
+
C[ldc * (jj + 5) + (ii + 2)] = hsum(Cv52);
|
|
810
|
+
C[ldc * (jj + 5) + (ii + 3)] = hsum(Cv53);
|
|
811
|
+
}
|
|
812
|
+
|
|
813
|
+
inline void gemm_bloc_4x5(int64_t ii, int64_t jj) {
|
|
814
|
+
size_t vl = vlmax<V>();
|
|
815
|
+
D Cv00 = set_zero<D>();
|
|
816
|
+
D Cv01 = set_zero<D>();
|
|
817
|
+
D Cv02 = set_zero<D>();
|
|
818
|
+
D Cv03 = set_zero<D>();
|
|
819
|
+
D Cv10 = set_zero<D>();
|
|
820
|
+
D Cv11 = set_zero<D>();
|
|
821
|
+
D Cv12 = set_zero<D>();
|
|
822
|
+
D Cv13 = set_zero<D>();
|
|
823
|
+
D Cv20 = set_zero<D>();
|
|
824
|
+
D Cv21 = set_zero<D>();
|
|
825
|
+
D Cv22 = set_zero<D>();
|
|
826
|
+
D Cv23 = set_zero<D>();
|
|
827
|
+
D Cv30 = set_zero<D>();
|
|
828
|
+
D Cv31 = set_zero<D>();
|
|
829
|
+
D Cv32 = set_zero<D>();
|
|
830
|
+
D Cv33 = set_zero<D>();
|
|
831
|
+
D Cv40 = set_zero<D>();
|
|
832
|
+
D Cv41 = set_zero<D>();
|
|
833
|
+
D Cv42 = set_zero<D>();
|
|
834
|
+
D Cv43 = set_zero<D>();
|
|
835
|
+
|
|
836
|
+
for (int64_t l = 0; l < k; l += vl) {
|
|
837
|
+
V Bv0 = load<V>(B + ldb * (jj + 0) + l);
|
|
838
|
+
V Bv1 = load<V>(B + ldb * (jj + 1) + l);
|
|
839
|
+
V Bv2 = load<V>(B + ldb * (jj + 2) + l);
|
|
840
|
+
V Bv3 = load<V>(B + ldb * (jj + 3) + l);
|
|
841
|
+
V Bv4 = load<V>(B + ldb * (jj + 4) + l);
|
|
842
|
+
|
|
843
|
+
V Av0 = load<V>(A + lda * (ii + 0) + l);
|
|
844
|
+
Cv00 = madd(Av0, Bv0, Cv00);
|
|
845
|
+
Cv10 = madd(Av0, Bv1, Cv10);
|
|
846
|
+
Cv20 = madd(Av0, Bv2, Cv20);
|
|
847
|
+
Cv30 = madd(Av0, Bv3, Cv30);
|
|
848
|
+
Cv40 = madd(Av0, Bv4, Cv40);
|
|
849
|
+
|
|
850
|
+
V Av1 = load<V>(A + lda * (ii + 1) + l);
|
|
851
|
+
Cv01 = madd(Av1, Bv0, Cv01);
|
|
852
|
+
Cv11 = madd(Av1, Bv1, Cv11);
|
|
853
|
+
Cv21 = madd(Av1, Bv2, Cv21);
|
|
854
|
+
Cv31 = madd(Av1, Bv3, Cv31);
|
|
855
|
+
Cv41 = madd(Av1, Bv4, Cv41);
|
|
856
|
+
|
|
857
|
+
V Av2 = load<V>(A + lda * (ii + 2) + l);
|
|
858
|
+
Cv02 = madd(Av2, Bv0, Cv02);
|
|
859
|
+
Cv12 = madd(Av2, Bv1, Cv12);
|
|
860
|
+
Cv22 = madd(Av2, Bv2, Cv22);
|
|
861
|
+
Cv32 = madd(Av2, Bv3, Cv32);
|
|
862
|
+
Cv42 = madd(Av2, Bv4, Cv42);
|
|
863
|
+
|
|
864
|
+
V Av3 = load<V>(A + lda * (ii + 3) + l);
|
|
865
|
+
Cv03 = madd(Av3, Bv0, Cv03);
|
|
866
|
+
Cv13 = madd(Av3, Bv1, Cv13);
|
|
867
|
+
Cv23 = madd(Av3, Bv2, Cv23);
|
|
868
|
+
Cv33 = madd(Av3, Bv3, Cv33);
|
|
869
|
+
Cv43 = madd(Av3, Bv4, Cv43);
|
|
870
|
+
}
|
|
871
|
+
|
|
872
|
+
C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
|
|
873
|
+
C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
|
|
874
|
+
C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
|
|
875
|
+
C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
|
|
876
|
+
C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
|
|
877
|
+
C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
|
|
878
|
+
C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
|
|
879
|
+
C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
|
|
880
|
+
C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
|
|
881
|
+
C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
|
|
882
|
+
C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
|
|
883
|
+
C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
|
|
884
|
+
C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
|
|
885
|
+
C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
|
|
886
|
+
C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
|
|
887
|
+
C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
|
|
888
|
+
C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
|
|
889
|
+
C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
|
|
890
|
+
C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
|
|
891
|
+
C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
|
|
892
|
+
}
|
|
893
|
+
|
|
894
|
+
inline void gemm_bloc_4x4(int64_t ii, int64_t jj) {
|
|
895
|
+
size_t vl = vlmax<V>();
|
|
896
|
+
D Cv00 = set_zero<D>();
|
|
897
|
+
D Cv01 = set_zero<D>();
|
|
898
|
+
D Cv02 = set_zero<D>();
|
|
899
|
+
D Cv03 = set_zero<D>();
|
|
900
|
+
D Cv10 = set_zero<D>();
|
|
901
|
+
D Cv11 = set_zero<D>();
|
|
902
|
+
D Cv12 = set_zero<D>();
|
|
903
|
+
D Cv13 = set_zero<D>();
|
|
904
|
+
D Cv20 = set_zero<D>();
|
|
905
|
+
D Cv21 = set_zero<D>();
|
|
906
|
+
D Cv22 = set_zero<D>();
|
|
907
|
+
D Cv23 = set_zero<D>();
|
|
908
|
+
D Cv30 = set_zero<D>();
|
|
909
|
+
D Cv31 = set_zero<D>();
|
|
910
|
+
D Cv32 = set_zero<D>();
|
|
911
|
+
D Cv33 = set_zero<D>();
|
|
912
|
+
|
|
913
|
+
for (int64_t l = 0; l < k; l += vl) {
|
|
914
|
+
V Av0 = load<V>(A + lda * (ii + 0) + l);
|
|
915
|
+
V Av1 = load<V>(A + lda * (ii + 1) + l);
|
|
916
|
+
V Av2 = load<V>(A + lda * (ii + 2) + l);
|
|
917
|
+
V Av3 = load<V>(A + lda * (ii + 3) + l);
|
|
918
|
+
|
|
919
|
+
V Bv0 = load<V>(B + ldb * (jj + 0) + l);
|
|
920
|
+
Cv00 = madd(Av0, Bv0, Cv00);
|
|
921
|
+
Cv01 = madd(Av1, Bv0, Cv01);
|
|
922
|
+
Cv02 = madd(Av2, Bv0, Cv02);
|
|
923
|
+
Cv03 = madd(Av3, Bv0, Cv03);
|
|
924
|
+
|
|
925
|
+
V Bv1 = load<V>(B + ldb * (jj + 1) + l);
|
|
926
|
+
Cv10 = madd(Av0, Bv1, Cv10);
|
|
927
|
+
Cv11 = madd(Av1, Bv1, Cv11);
|
|
928
|
+
Cv12 = madd(Av2, Bv1, Cv12);
|
|
929
|
+
Cv13 = madd(Av3, Bv1, Cv13);
|
|
930
|
+
|
|
931
|
+
V Bv2 = load<V>(B + ldb * (jj + 2) + l);
|
|
932
|
+
Cv20 = madd(Av0, Bv2, Cv20);
|
|
933
|
+
Cv21 = madd(Av1, Bv2, Cv21);
|
|
934
|
+
Cv22 = madd(Av2, Bv2, Cv22);
|
|
935
|
+
Cv23 = madd(Av3, Bv2, Cv23);
|
|
936
|
+
|
|
937
|
+
V Bv3 = load<V>(B + ldb * (jj + 3) + l);
|
|
938
|
+
Cv30 = madd(Av0, Bv3, Cv30);
|
|
939
|
+
Cv31 = madd(Av1, Bv3, Cv31);
|
|
940
|
+
Cv32 = madd(Av2, Bv3, Cv32);
|
|
941
|
+
Cv33 = madd(Av3, Bv3, Cv33);
|
|
942
|
+
}
|
|
943
|
+
|
|
944
|
+
C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
|
|
945
|
+
C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
|
|
946
|
+
C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
|
|
947
|
+
C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
|
|
948
|
+
C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
|
|
949
|
+
C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
|
|
950
|
+
C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
|
|
951
|
+
C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
|
|
952
|
+
C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
|
|
953
|
+
C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
|
|
954
|
+
C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
|
|
955
|
+
C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
|
|
956
|
+
C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
|
|
957
|
+
C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
|
|
958
|
+
C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
|
|
959
|
+
C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
|
|
960
|
+
}
|
|
961
|
+
|
|
962
|
+
inline void gemm_bloc_4x3(int64_t ii, int64_t jj) {
|
|
963
|
+
size_t vl = vlmax<V>();
|
|
964
|
+
D Cv00 = set_zero<D>();
|
|
965
|
+
D Cv01 = set_zero<D>();
|
|
966
|
+
D Cv02 = set_zero<D>();
|
|
967
|
+
D Cv03 = set_zero<D>();
|
|
968
|
+
D Cv10 = set_zero<D>();
|
|
969
|
+
D Cv11 = set_zero<D>();
|
|
970
|
+
D Cv12 = set_zero<D>();
|
|
971
|
+
D Cv13 = set_zero<D>();
|
|
972
|
+
D Cv20 = set_zero<D>();
|
|
973
|
+
D Cv21 = set_zero<D>();
|
|
974
|
+
D Cv22 = set_zero<D>();
|
|
975
|
+
D Cv23 = set_zero<D>();
|
|
976
|
+
|
|
977
|
+
for (int64_t l = 0; l < k; l += vl) {
|
|
978
|
+
V Av0 = load<V>(A + lda * (ii + 0) + l);
|
|
979
|
+
V Av1 = load<V>(A + lda * (ii + 1) + l);
|
|
980
|
+
V Av2 = load<V>(A + lda * (ii + 2) + l);
|
|
981
|
+
V Av3 = load<V>(A + lda * (ii + 3) + l);
|
|
982
|
+
|
|
983
|
+
V Bv0 = load<V>(B + ldb * (jj + 0) + l);
|
|
984
|
+
Cv00 = madd(Av0, Bv0, Cv00);
|
|
985
|
+
Cv01 = madd(Av1, Bv0, Cv01);
|
|
986
|
+
Cv02 = madd(Av2, Bv0, Cv02);
|
|
987
|
+
Cv03 = madd(Av3, Bv0, Cv03);
|
|
988
|
+
|
|
989
|
+
V Bv1 = load<V>(B + ldb * (jj + 1) + l);
|
|
990
|
+
Cv10 = madd(Av0, Bv1, Cv10);
|
|
991
|
+
Cv11 = madd(Av1, Bv1, Cv11);
|
|
992
|
+
Cv12 = madd(Av2, Bv1, Cv12);
|
|
993
|
+
Cv13 = madd(Av3, Bv1, Cv13);
|
|
994
|
+
|
|
995
|
+
V Bv2 = load<V>(B + ldb * (jj + 2) + l);
|
|
996
|
+
Cv20 = madd(Av0, Bv2, Cv20);
|
|
997
|
+
Cv21 = madd(Av1, Bv2, Cv21);
|
|
998
|
+
Cv22 = madd(Av2, Bv2, Cv22);
|
|
999
|
+
Cv23 = madd(Av3, Bv2, Cv23);
|
|
1000
|
+
}
|
|
1001
|
+
|
|
1002
|
+
C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
|
|
1003
|
+
C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
|
|
1004
|
+
C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
|
|
1005
|
+
C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
|
|
1006
|
+
C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
|
|
1007
|
+
C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
|
|
1008
|
+
C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
|
|
1009
|
+
C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
|
|
1010
|
+
C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
|
|
1011
|
+
C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
|
|
1012
|
+
C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
|
|
1013
|
+
C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
|
|
1014
|
+
}
|
|
1015
|
+
|
|
1016
|
+
inline void gemm_bloc_4x2(int64_t ii, int64_t jj) {
|
|
1017
|
+
size_t vl = vlmax<V>();
|
|
1018
|
+
D Cv00 = set_zero<D>();
|
|
1019
|
+
D Cv01 = set_zero<D>();
|
|
1020
|
+
D Cv02 = set_zero<D>();
|
|
1021
|
+
D Cv03 = set_zero<D>();
|
|
1022
|
+
D Cv10 = set_zero<D>();
|
|
1023
|
+
D Cv11 = set_zero<D>();
|
|
1024
|
+
D Cv12 = set_zero<D>();
|
|
1025
|
+
D Cv13 = set_zero<D>();
|
|
1026
|
+
|
|
1027
|
+
for (int64_t l = 0; l < k; l += vl) {
|
|
1028
|
+
V Av0 = load<V>(A + lda * (ii + 0) + l);
|
|
1029
|
+
V Av1 = load<V>(A + lda * (ii + 1) + l);
|
|
1030
|
+
V Av2 = load<V>(A + lda * (ii + 2) + l);
|
|
1031
|
+
V Av3 = load<V>(A + lda * (ii + 3) + l);
|
|
1032
|
+
|
|
1033
|
+
V Bv0 = load<V>(B + ldb * (jj + 0) + l);
|
|
1034
|
+
Cv00 = madd(Av0, Bv0, Cv00);
|
|
1035
|
+
Cv01 = madd(Av1, Bv0, Cv01);
|
|
1036
|
+
Cv02 = madd(Av2, Bv0, Cv02);
|
|
1037
|
+
Cv03 = madd(Av3, Bv0, Cv03);
|
|
1038
|
+
|
|
1039
|
+
V Bv1 = load<V>(B + ldb * (jj + 1) + l);
|
|
1040
|
+
Cv10 = madd(Av0, Bv1, Cv10);
|
|
1041
|
+
Cv11 = madd(Av1, Bv1, Cv11);
|
|
1042
|
+
Cv12 = madd(Av2, Bv1, Cv12);
|
|
1043
|
+
Cv13 = madd(Av3, Bv1, Cv13);
|
|
1044
|
+
}
|
|
1045
|
+
|
|
1046
|
+
C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
|
|
1047
|
+
C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
|
|
1048
|
+
C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
|
|
1049
|
+
C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
|
|
1050
|
+
C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
|
|
1051
|
+
C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
|
|
1052
|
+
C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
|
|
1053
|
+
C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
|
|
1054
|
+
}
|
|
1055
|
+
|
|
1056
|
+
inline void gemm_bloc_4x1(int64_t ii, int64_t jj) {
|
|
1057
|
+
size_t vl = vlmax<V>();
|
|
1058
|
+
D Cv00 = set_zero<D>();
|
|
1059
|
+
D Cv01 = set_zero<D>();
|
|
1060
|
+
D Cv02 = set_zero<D>();
|
|
1061
|
+
D Cv03 = set_zero<D>();
|
|
1062
|
+
|
|
1063
|
+
for (int64_t l = 0; l < k; l += vl) {
|
|
1064
|
+
V Av0 = load<V>(A + lda * (ii + 0) + l);
|
|
1065
|
+
V Av1 = load<V>(A + lda * (ii + 1) + l);
|
|
1066
|
+
V Av2 = load<V>(A + lda * (ii + 2) + l);
|
|
1067
|
+
V Av3 = load<V>(A + lda * (ii + 3) + l);
|
|
1068
|
+
|
|
1069
|
+
V Bv0 = load<V>(B + ldb * (jj + 0) + l);
|
|
1070
|
+
Cv00 = madd(Av0, Bv0, Cv00);
|
|
1071
|
+
Cv01 = madd(Av1, Bv0, Cv01);
|
|
1072
|
+
Cv02 = madd(Av2, Bv0, Cv02);
|
|
1073
|
+
Cv03 = madd(Av3, Bv0, Cv03);
|
|
1074
|
+
}
|
|
1075
|
+
|
|
1076
|
+
C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
|
|
1077
|
+
C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
|
|
1078
|
+
C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
|
|
1079
|
+
C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
|
|
1080
|
+
}
|
|
1081
|
+
|
|
1082
|
+
inline void gemm_bloc_2x2(int64_t ii, int64_t jj) {
|
|
1083
|
+
size_t vl = vlmax<V>();
|
|
1084
|
+
D Cv00 = set_zero<D>();
|
|
1085
|
+
D Cv01 = set_zero<D>();
|
|
1086
|
+
D Cv10 = set_zero<D>();
|
|
1087
|
+
D Cv11 = set_zero<D>();
|
|
1088
|
+
|
|
1089
|
+
for (int64_t l = 0; l < k; l += vl) {
|
|
1090
|
+
V Av0 = load<V>(A + lda * (ii + 0) + l);
|
|
1091
|
+
V Av1 = load<V>(A + lda * (ii + 1) + l);
|
|
1092
|
+
|
|
1093
|
+
V Bv0 = load<V>(B + ldb * (jj + 0) + l);
|
|
1094
|
+
Cv00 = madd(Av0, Bv0, Cv00);
|
|
1095
|
+
Cv01 = madd(Av1, Bv0, Cv01);
|
|
1096
|
+
|
|
1097
|
+
V Bv1 = load<V>(B + ldb * (jj + 1) + l);
|
|
1098
|
+
Cv10 = madd(Av0, Bv1, Cv10);
|
|
1099
|
+
Cv11 = madd(Av1, Bv1, Cv11);
|
|
1100
|
+
}
|
|
1101
|
+
|
|
1102
|
+
C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
|
|
1103
|
+
C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
|
|
1104
|
+
C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
|
|
1105
|
+
C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
|
|
1106
|
+
}
|
|
1107
|
+
|
|
1108
|
+
inline void gemm_bloc_2x1(int64_t ii, int64_t jj) {
|
|
1109
|
+
size_t vl = vlmax<V>();
|
|
1110
|
+
D Cv00 = set_zero<D>();
|
|
1111
|
+
D Cv01 = set_zero<D>();
|
|
1112
|
+
|
|
1113
|
+
for (int64_t l = 0; l < k; l += vl) {
|
|
1114
|
+
V Av0 = load<V>(A + lda * (ii + 0) + l);
|
|
1115
|
+
V Av1 = load<V>(A + lda * (ii + 1) + l);
|
|
1116
|
+
|
|
1117
|
+
V Bv0 = load<V>(B + ldb * (jj + 0) + l);
|
|
1118
|
+
Cv00 = madd(Av0, Bv0, Cv00);
|
|
1119
|
+
Cv01 = madd(Av1, Bv0, Cv01);
|
|
1120
|
+
}
|
|
1121
|
+
|
|
1122
|
+
C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
|
|
1123
|
+
C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
|
|
1124
|
+
}
|
|
1125
|
+
|
|
1126
|
+
template <int RM, int RN>
|
|
1127
|
+
inline void gemm_bloc(int64_t ii, int64_t jj) {
|
|
1128
|
+
if constexpr (RM == 4) {
|
|
1129
|
+
if constexpr (RN == 6) { return gemm_bloc_4x6(ii, jj); }
|
|
1130
|
+
if constexpr (RN == 5) { return gemm_bloc_4x5(ii, jj); }
|
|
1131
|
+
if constexpr (RN == 4) { return gemm_bloc_4x4(ii, jj); }
|
|
1132
|
+
if constexpr (RN == 3) { return gemm_bloc_4x3(ii, jj); }
|
|
1133
|
+
if constexpr (RN == 2) { return gemm_bloc_4x2(ii, jj); }
|
|
1134
|
+
if constexpr (RN == 1) { return gemm_bloc_4x1(ii, jj); }
|
|
1135
|
+
} else if constexpr (RM == 2) {
|
|
1136
|
+
if constexpr (RN == 2) { return gemm_bloc_2x2(ii, jj); }
|
|
1137
|
+
if constexpr (RN == 1) { return gemm_bloc_2x1(ii, jj); }
|
|
1138
|
+
}
|
|
1139
|
+
}
|
|
1140
|
+
|
|
1141
|
+
template <int RM, int RN, int BM>
|
|
1142
|
+
NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
|
|
1143
|
+
GGML_ASSERT(m % (RM * BM) == 0);
|
|
1144
|
+
const int64_t ytiles = m / (RM * BM);
|
|
1145
|
+
const int64_t xtiles = (n + RN -1) / RN;
|
|
1146
|
+
const int64_t jj_RN = (xtiles - (xtiles * RN - n));
|
|
1147
|
+
|
|
1148
|
+
// "round" bloc_size to "nearest" BN
|
|
1149
|
+
const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
|
|
1150
|
+
const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
|
|
1151
|
+
const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
|
|
1152
|
+
const int64_t nb_job = ytiles * NB_BN;
|
|
1153
|
+
|
|
1154
|
+
if (params->ith == 0) {
|
|
1155
|
+
GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
|
|
1156
|
+
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
|
|
1157
|
+
ggml_threadpool_chunk_set(params->threadpool, params->nth);
|
|
1158
|
+
}
|
|
1159
|
+
|
|
1160
|
+
ggml_barrier(params->threadpool);
|
|
1161
|
+
|
|
1162
|
+
int64_t job = params->ith;
|
|
1163
|
+
while (job < nb_job) {
|
|
1164
|
+
const int64_t ii = (job % ytiles) * RM * BM;
|
|
1165
|
+
const int64_t jb = job / ytiles;
|
|
1166
|
+
const int64_t jr0 = BLOC_POS(jb , jj_BN, SIZE_BN);
|
|
1167
|
+
const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
|
|
1168
|
+
|
|
1169
|
+
const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
|
|
1170
|
+
const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
|
|
1171
|
+
const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
|
|
1172
|
+
|
|
1173
|
+
for (int64_t bi = 0; bi < BM * RM; bi += RM) {
|
|
1174
|
+
int64_t jj = jj0;
|
|
1175
|
+
for (; jj < jj1; jj += RN) {
|
|
1176
|
+
gemm_bloc<RM, RN>(ii + bi, jj);
|
|
1177
|
+
}
|
|
1178
|
+
if constexpr (RN > 1) {
|
|
1179
|
+
for (; jj < jj2; jj += RN - 1) {
|
|
1180
|
+
gemm_bloc<RM, RN-1>(ii + bi, jj);
|
|
1181
|
+
}
|
|
1182
|
+
}
|
|
1183
|
+
GGML_ASSERT(jj == jj2);
|
|
1184
|
+
}
|
|
1185
|
+
|
|
1186
|
+
job = ggml_threadpool_chunk_add(params->threadpool, 1);
|
|
1187
|
+
}
|
|
1188
|
+
|
|
1189
|
+
ggml_barrier(params->threadpool);
|
|
1190
|
+
return;
|
|
1191
|
+
}
|
|
1192
|
+
|
|
1193
|
+
const ggml_compute_params * params;
|
|
1194
|
+
const TA *const A;
|
|
1195
|
+
const TB *const B;
|
|
1196
|
+
TC *const C;
|
|
1197
|
+
const int64_t k;
|
|
1198
|
+
const int64_t lda;
|
|
1199
|
+
const int64_t ldb;
|
|
1200
|
+
const int64_t ldc;
|
|
1201
|
+
};
|
|
1202
|
+
#endif
|
|
1203
|
+
|
|
492
1204
|
//////////////////////////////////////////////////////////////////////////////////////////
|
|
493
1205
|
// QUANT ZERO MATRIX MULTIPLICATION
|
|
494
1206
|
|
|
@@ -1086,10 +1798,27 @@ class tinyBLAS_Q0_AVX {
|
|
|
1086
1798
|
} \
|
|
1087
1799
|
} \
|
|
1088
1800
|
|
|
1801
|
+
template<typename T>
|
|
1802
|
+
struct mma_instr;
|
|
1803
|
+
|
|
1804
|
+
template<>
|
|
1805
|
+
struct mma_instr<ggml_bf16_t> {
|
|
1806
|
+
static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
|
|
1807
|
+
__builtin_mma_xvbf16ger2pp(acc, a, b);
|
|
1808
|
+
}
|
|
1809
|
+
};
|
|
1810
|
+
|
|
1811
|
+
template<>
|
|
1812
|
+
struct mma_instr<ggml_fp16_t> {
|
|
1813
|
+
static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
|
|
1814
|
+
__builtin_mma_xvf16ger2pp(acc, a, b);
|
|
1815
|
+
}
|
|
1816
|
+
};
|
|
1817
|
+
|
|
1089
1818
|
template <typename TA, typename TB, typename TC>
|
|
1090
|
-
class
|
|
1819
|
+
class tinyBLAS_HP16_PPC {
|
|
1091
1820
|
public:
|
|
1092
|
-
|
|
1821
|
+
tinyBLAS_HP16_PPC(int64_t k,
|
|
1093
1822
|
const TA *A, int64_t lda,
|
|
1094
1823
|
const TB *B, int64_t ldb,
|
|
1095
1824
|
TC *C, int64_t ldc,
|
|
@@ -1407,8 +2136,8 @@ class tinyBLAS_BF16_PPC {
|
|
|
1407
2136
|
packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
|
|
1408
2137
|
packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
|
|
1409
2138
|
for (int x = 0; x < 4; x++) {
|
|
1410
|
-
|
|
1411
|
-
|
|
2139
|
+
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
|
2140
|
+
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
|
|
1412
2141
|
}
|
|
1413
2142
|
}
|
|
1414
2143
|
SAVE_ACC(&acc_0, ii, jj);
|
|
@@ -1424,8 +2153,8 @@ class tinyBLAS_BF16_PPC {
|
|
|
1424
2153
|
packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
|
|
1425
2154
|
packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
|
|
1426
2155
|
for (int x = 0; x < 4; x++) {
|
|
1427
|
-
|
|
1428
|
-
|
|
2156
|
+
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
|
2157
|
+
mma_instr<TA>::outer_product(&acc_1, vec_A[x+4], vec_B[x]);
|
|
1429
2158
|
}
|
|
1430
2159
|
}
|
|
1431
2160
|
SAVE_ACC(&acc_0, ii, jj);
|
|
@@ -1444,10 +2173,10 @@ class tinyBLAS_BF16_PPC {
|
|
|
1444
2173
|
packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
|
|
1445
2174
|
packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
|
|
1446
2175
|
for (int x = 0; x < 4; x++) {
|
|
1447
|
-
|
|
1448
|
-
|
|
1449
|
-
|
|
1450
|
-
|
|
2176
|
+
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
|
2177
|
+
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
|
|
2178
|
+
mma_instr<TA>::outer_product(&acc_2, vec_A[x+4], vec_B[x]);
|
|
2179
|
+
mma_instr<TA>::outer_product(&acc_3, vec_A[x+4], vec_B[x+4]);
|
|
1451
2180
|
}
|
|
1452
2181
|
}
|
|
1453
2182
|
|
|
@@ -1478,7 +2207,7 @@ class tinyBLAS_BF16_PPC {
|
|
|
1478
2207
|
packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
|
|
1479
2208
|
packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
|
|
1480
2209
|
for (int x = 0; x<2; x++) {
|
|
1481
|
-
|
|
2210
|
+
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
|
1482
2211
|
}
|
|
1483
2212
|
}
|
|
1484
2213
|
__builtin_mma_disassemble_acc(vec_C, &acc_0);
|
|
@@ -1513,8 +2242,8 @@ class tinyBLAS_BF16_PPC {
|
|
|
1513
2242
|
packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
|
|
1514
2243
|
packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
|
|
1515
2244
|
for (int x = 0; x<4; x++) {
|
|
1516
|
-
|
|
1517
|
-
|
|
2245
|
+
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
|
2246
|
+
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
|
|
1518
2247
|
}
|
|
1519
2248
|
}
|
|
1520
2249
|
__builtin_mma_disassemble_acc(vec_C, &acc_0);
|
|
@@ -1577,44 +2306,91 @@ template <typename TA>
|
|
|
1577
2306
|
class tinyBLAS_Q0_PPC {
|
|
1578
2307
|
public:
|
|
1579
2308
|
tinyBLAS_Q0_PPC(int64_t k,
|
|
1580
|
-
|
|
1581
|
-
|
|
1582
|
-
|
|
1583
|
-
|
|
2309
|
+
const TA * A, int64_t lda,
|
|
2310
|
+
const block_q8_0 * B, int64_t ldb,
|
|
2311
|
+
float * C, int64_t ldc,
|
|
2312
|
+
int ith, int nth)
|
|
1584
2313
|
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
|
1585
2314
|
}
|
|
1586
2315
|
|
|
1587
2316
|
void matmul(int64_t m, int64_t n) {
|
|
1588
|
-
|
|
2317
|
+
const int64_t mc = 64;
|
|
2318
|
+
const int64_t kc = 64;
|
|
2319
|
+
int64_t nc = 64;
|
|
2320
|
+
int64_t n_aligned = 0;
|
|
2321
|
+
if (n % 64 == 0) {
|
|
2322
|
+
n_aligned = n;
|
|
2323
|
+
} else if (n == 4) {
|
|
2324
|
+
n_aligned = 4;
|
|
2325
|
+
} else if (n < 64) {
|
|
2326
|
+
n_aligned = (n / 8) * 8;
|
|
2327
|
+
} else {
|
|
2328
|
+
n_aligned = (n / 64) * 64;
|
|
2329
|
+
}
|
|
2330
|
+
|
|
2331
|
+
if (n_aligned > 0) {
|
|
2332
|
+
if (n_aligned % 64 == 0) nc = 64;
|
|
2333
|
+
else if (n_aligned == n) nc = n;
|
|
2334
|
+
else if (n_aligned % 32 == 0) nc = 32;
|
|
2335
|
+
else if (n_aligned % 24 == 0) nc = 24;
|
|
2336
|
+
else if (n_aligned % 16 == 0) nc = 16;
|
|
2337
|
+
else nc = 8;
|
|
2338
|
+
}
|
|
2339
|
+
bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0);
|
|
2340
|
+
if (can_use_tiled) {
|
|
2341
|
+
matmul_tiled(m, n_aligned, mc, nc, kc);
|
|
2342
|
+
if (n > n_aligned) {
|
|
2343
|
+
mnpack(0, m, n_aligned, n);
|
|
2344
|
+
}
|
|
2345
|
+
} else {
|
|
2346
|
+
mnpack(0, m, 0, n);
|
|
2347
|
+
}
|
|
1589
2348
|
}
|
|
1590
2349
|
|
|
1591
2350
|
private:
|
|
2351
|
+
inline void save_res(int ii, int jj, int idx, vector float * fin_res, int RM = 4, int RN = 4) {
|
|
2352
|
+
for (int I = 0; I < RM; I++) {
|
|
2353
|
+
for (int J = 0; J < RN; J++) {
|
|
2354
|
+
*((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&fin_res[idx + I] + J);
|
|
2355
|
+
}
|
|
2356
|
+
}
|
|
2357
|
+
}
|
|
1592
2358
|
|
|
1593
|
-
inline void
|
|
1594
|
-
|
|
1595
|
-
|
|
1596
|
-
|
|
1597
|
-
|
|
1598
|
-
|
|
2359
|
+
inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
|
|
2360
|
+
vec_t vec_C[4];
|
|
2361
|
+
__builtin_mma_disassemble_acc(vec_C, ACC);
|
|
2362
|
+
for (int I = 0; I < 4; I++) {
|
|
2363
|
+
for (int J = 0; J < 4; J++) {
|
|
2364
|
+
*((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&vec_C[I] + J);
|
|
2365
|
+
}
|
|
2366
|
+
}
|
|
1599
2367
|
}
|
|
1600
2368
|
|
|
1601
|
-
|
|
1602
|
-
|
|
1603
|
-
|
|
1604
|
-
|
|
1605
|
-
|
|
1606
|
-
|
|
1607
|
-
|
|
1608
|
-
|
|
1609
|
-
|
|
1610
|
-
fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
|
|
1611
|
-
}
|
|
2369
|
+
inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
|
|
2370
|
+
vec_t vec_C[4];
|
|
2371
|
+
__builtin_mma_disassemble_acc(vec_C, ACC);
|
|
2372
|
+
for (int I = 0; I < 4; I++) {
|
|
2373
|
+
for (int J = 0; J < 4; J++) {
|
|
2374
|
+
float * c_ptr = (float *)(C + ii+ ((jj + J) * ldc) + I);
|
|
2375
|
+
*c_ptr += *((float *)&vec_C[I] + J);
|
|
2376
|
+
}
|
|
2377
|
+
}
|
|
1612
2378
|
}
|
|
1613
|
-
|
|
1614
|
-
|
|
1615
|
-
|
|
1616
|
-
|
|
1617
|
-
|
|
2379
|
+
|
|
2380
|
+
template<typename ArrayType>
|
|
2381
|
+
inline void compute(acc_t * ACC, int c_idx, int s_idx, ArrayType & comparray, vector float * vs, vector float * fin_res) {
|
|
2382
|
+
vector signed int vec_C[4];
|
|
2383
|
+
vector float CA[4] = {0};
|
|
2384
|
+
vector float res[4] = {0};
|
|
2385
|
+
__builtin_mma_disassemble_acc(vec_C, ACC);
|
|
2386
|
+
for (int i = 0; i < 4; i++) {
|
|
2387
|
+
CA[i] = vec_splats((float)(((double)comparray[c_idx + i]) * -128.0));
|
|
2388
|
+
res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
|
|
2389
|
+
fin_res[s_idx + i] = vec_madd(res[i], vs[s_idx + i], fin_res[s_idx + i]);
|
|
2390
|
+
}
|
|
2391
|
+
}
|
|
2392
|
+
|
|
2393
|
+
inline void process_q4_elements(vector signed char (&c)[2], int * ca) {
|
|
1618
2394
|
const vector signed char lowMask = vec_splats((signed char)0xF);
|
|
1619
2395
|
const vector unsigned char v4 = vec_splats((unsigned char)0x4);
|
|
1620
2396
|
const vector signed char v8 = vec_splats((signed char)0x8);
|
|
@@ -1631,7 +2407,7 @@ class tinyBLAS_Q0_PPC {
|
|
|
1631
2407
|
}
|
|
1632
2408
|
|
|
1633
2409
|
template <typename V1, typename V2>
|
|
1634
|
-
inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
|
|
2410
|
+
inline void vector_permute_store(V2 & s1, V2 & s2, V2 & s3, V2 & s4, V1 * vecOffset, bool flip) {
|
|
1635
2411
|
vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
|
1636
2412
|
vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
|
1637
2413
|
vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
|
|
@@ -1655,21 +2431,170 @@ class tinyBLAS_Q0_PPC {
|
|
|
1655
2431
|
t8 = vec_xor(t8, xor_vector);
|
|
1656
2432
|
}
|
|
1657
2433
|
vec_xst(t5, 0, vecOffset);
|
|
1658
|
-
vec_xst(t6, 0, vecOffset+16);
|
|
1659
|
-
vec_xst(t7, 0, vecOffset+32);
|
|
1660
|
-
vec_xst(t8, 0, vecOffset+48);
|
|
2434
|
+
vec_xst(t6, 0, vecOffset + 16);
|
|
2435
|
+
vec_xst(t7, 0, vecOffset + 32);
|
|
2436
|
+
vec_xst(t8, 0, vecOffset + 48);
|
|
2437
|
+
}
|
|
2438
|
+
|
|
2439
|
+
inline void unpack_q4_to_q8(vector signed char packed, vector signed char & lo, vector signed char & hi) {
|
|
2440
|
+
const vector signed char lowMask = vec_splats((signed char)0x0F);
|
|
2441
|
+
const vector signed char v8 = vec_splats((signed char)0x08);
|
|
2442
|
+
const vector unsigned char v4 = vec_splats((unsigned char)4);
|
|
2443
|
+
lo = vec_and(packed, lowMask);
|
|
2444
|
+
hi = vec_sr(packed, v4);
|
|
2445
|
+
lo = vec_sub(lo, v8);
|
|
2446
|
+
hi = vec_sub(hi, v8);
|
|
2447
|
+
}
|
|
2448
|
+
|
|
2449
|
+
inline void vector_permute_store_fp16(vec_t * c, unsigned char * vecOffset) {
|
|
2450
|
+
vec_t t[8], s[8];
|
|
2451
|
+
vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
|
|
2452
|
+
vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
|
|
2453
|
+
vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
|
2454
|
+
vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
|
2455
|
+
for (int i = 0; i < 4; i += 2) {
|
|
2456
|
+
t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
|
|
2457
|
+
t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
|
|
2458
|
+
}
|
|
2459
|
+
for (int i = 4; i < 8; i += 2) {
|
|
2460
|
+
t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
|
|
2461
|
+
t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
|
|
2462
|
+
}
|
|
2463
|
+
s[0] = vec_perm(t[0], t[2], swiz3);
|
|
2464
|
+
s[1] = vec_perm(t[0], t[2], swiz4);
|
|
2465
|
+
s[2] = vec_perm(t[1], t[3], swiz3);
|
|
2466
|
+
s[3] = vec_perm(t[1], t[3], swiz4);
|
|
2467
|
+
s[4] = vec_perm(t[4], t[6], swiz3);
|
|
2468
|
+
s[5] = vec_perm(t[4], t[6], swiz4);
|
|
2469
|
+
s[6] = vec_perm(t[5], t[7], swiz3);
|
|
2470
|
+
s[7] = vec_perm(t[5], t[7], swiz4);
|
|
2471
|
+
for (int i = 0; i < 8; ++i) {
|
|
2472
|
+
vec_xst(s[i], 0, (vec_t *)(vecOffset + i * 16));
|
|
2473
|
+
}
|
|
2474
|
+
}
|
|
2475
|
+
|
|
2476
|
+
static inline void convert_and_scale_q8(vector signed char raw, vector float v_scale, vector unsigned short & out_hi, vector unsigned short & out_lo) {
|
|
2477
|
+
vector signed short i16_hi = vec_unpackh(raw);
|
|
2478
|
+
vector signed short i16_lo = vec_unpackl(raw);
|
|
2479
|
+
|
|
2480
|
+
vector float f_hi_h = vec_ctf(vec_unpackh(i16_hi), 0);
|
|
2481
|
+
vector float f_hi_l = vec_ctf(vec_unpackl(i16_hi), 0);
|
|
2482
|
+
vector float f_lo_h = vec_ctf(vec_unpackh(i16_lo), 0);
|
|
2483
|
+
vector float f_lo_l = vec_ctf(vec_unpackl(i16_lo), 0);
|
|
2484
|
+
out_hi = vec_pack_to_short_fp32(vec_mul(f_hi_h, v_scale), vec_mul(f_hi_l, v_scale));
|
|
2485
|
+
out_lo = vec_pack_to_short_fp32(vec_mul(f_lo_h, v_scale), vec_mul(f_lo_l, v_scale));
|
|
2486
|
+
}
|
|
2487
|
+
|
|
2488
|
+
void packNormal_q4_fp16(const block_q4_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
|
|
2489
|
+
unsigned char * vecOffset = vec;
|
|
2490
|
+
for (int i = 0; i < rows; i += 8) {
|
|
2491
|
+
const block_q4_0 * rows_base[8];
|
|
2492
|
+
for (int r = 0; r < 8; r++) {
|
|
2493
|
+
rows_base[r] = a + (i + r) * lda;
|
|
2494
|
+
}
|
|
2495
|
+
for (int blk = 0; blk < blocks; blk++) {
|
|
2496
|
+
vector unsigned short hp_res[8][4];
|
|
2497
|
+
for (int r = 0; r < 8; r++) {
|
|
2498
|
+
const block_q4_0 * current_blk = rows_base[r] + blk;
|
|
2499
|
+
vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d));
|
|
2500
|
+
vector signed char v_qs = vec_xl(0, (const vector signed char *)current_blk->qs);
|
|
2501
|
+
vector signed char c1, c2;
|
|
2502
|
+
unpack_q4_to_q8(v_qs, c1, c2);
|
|
2503
|
+
convert_and_scale_q8(c1, v_scale, hp_res[r][0], hp_res[r][1]);
|
|
2504
|
+
convert_and_scale_q8(c2, v_scale, hp_res[r][2], hp_res[r][3]);
|
|
2505
|
+
}
|
|
2506
|
+
for (int c = 0; c < 4; c++) {
|
|
2507
|
+
vector unsigned char c_arr[8];
|
|
2508
|
+
for (int r = 0; r < 8; r++) {
|
|
2509
|
+
c_arr[r] = (vector unsigned char)hp_res[r][c];
|
|
2510
|
+
}
|
|
2511
|
+
vector_permute_store_fp16((vec_t *)c_arr, vecOffset);
|
|
2512
|
+
vecOffset += 128;
|
|
2513
|
+
}
|
|
2514
|
+
}
|
|
2515
|
+
}
|
|
2516
|
+
}
|
|
2517
|
+
|
|
2518
|
+
template <int chunk_size>
|
|
2519
|
+
static inline void pack_q8_block(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
|
|
2520
|
+
unsigned char * vecOffset = vec;
|
|
2521
|
+
const vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
|
|
2522
|
+
const vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
|
|
2523
|
+
const vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
|
2524
|
+
const vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
|
2525
|
+
|
|
2526
|
+
for (int i = 0; i < rows; i += chunk_size) {
|
|
2527
|
+
const block_q8_0 * rows_base[chunk_size];
|
|
2528
|
+
for (int r = 0; r < chunk_size; r++) {
|
|
2529
|
+
rows_base[r] = a + (i + r) * lda;
|
|
2530
|
+
}
|
|
2531
|
+
for (int blk = 0; blk < blocks; blk++) {
|
|
2532
|
+
vector unsigned short hp_res[chunk_size][4];
|
|
2533
|
+
for (int r = 0; r < chunk_size; r++) {
|
|
2534
|
+
const block_q8_0 * b = rows_base[r] + blk;
|
|
2535
|
+
vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(b->d));
|
|
2536
|
+
vector signed char c[2];
|
|
2537
|
+
__vector_pair pair = __builtin_vsx_lxvp(0, (__vector_pair *)b->qs);
|
|
2538
|
+
__builtin_vsx_disassemble_pair(c, & pair);
|
|
2539
|
+
convert_and_scale_q8(c[0], v_scale, hp_res[r][0], hp_res[r][1]);
|
|
2540
|
+
convert_and_scale_q8(c[1], v_scale, hp_res[r][2], hp_res[r][3]);
|
|
2541
|
+
}
|
|
2542
|
+
for (int col = 0; col < 4; col++) {
|
|
2543
|
+
if constexpr (chunk_size == 8) {
|
|
2544
|
+
vec_t t[8];
|
|
2545
|
+
t[0] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
|
|
2546
|
+
t[1] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
|
|
2547
|
+
t[2] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
|
|
2548
|
+
t[3] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
|
|
2549
|
+
t[4] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz1);
|
|
2550
|
+
t[5] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz2);
|
|
2551
|
+
t[6] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz1);
|
|
2552
|
+
t[7] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz2);
|
|
2553
|
+
|
|
2554
|
+
vec_xst(vec_perm(t[0], t[2], swiz3), 0, (vec_t *)(vecOffset + 0));
|
|
2555
|
+
vec_xst(vec_perm(t[0], t[2], swiz4), 0, (vec_t *)(vecOffset + 16));
|
|
2556
|
+
vec_xst(vec_perm(t[1], t[3], swiz3), 0, (vec_t *)(vecOffset + 32));
|
|
2557
|
+
vec_xst(vec_perm(t[1], t[3], swiz4), 0, (vec_t *)(vecOffset + 48));
|
|
2558
|
+
vec_xst(vec_perm(t[4], t[6], swiz3), 0, (vec_t *)(vecOffset + 64));
|
|
2559
|
+
vec_xst(vec_perm(t[4], t[6], swiz4), 0, (vec_t *)(vecOffset + 80));
|
|
2560
|
+
vec_xst(vec_perm(t[5], t[7], swiz3), 0, (vec_t *)(vecOffset + 96));
|
|
2561
|
+
vec_xst(vec_perm(t[5], t[7], swiz4), 0, (vec_t *)(vecOffset + 112));
|
|
2562
|
+
vecOffset += 128;
|
|
2563
|
+
} else {
|
|
2564
|
+
vec_t t0 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
|
|
2565
|
+
vec_t t1 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
|
|
2566
|
+
vec_t t2 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
|
|
2567
|
+
vec_t t3 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
|
|
2568
|
+
|
|
2569
|
+
vec_xst(vec_perm(t0, t2, swiz3), 0, (vec_t *)(vecOffset + 0));
|
|
2570
|
+
vec_xst(vec_perm(t0, t2, swiz4), 0, (vec_t *)(vecOffset + 16));
|
|
2571
|
+
vec_xst(vec_perm(t1, t3, swiz3), 0, (vec_t *)(vecOffset + 32));
|
|
2572
|
+
vec_xst(vec_perm(t1, t3, swiz4), 0, (vec_t *)(vecOffset + 48));
|
|
2573
|
+
vecOffset += 64;
|
|
2574
|
+
}
|
|
2575
|
+
}
|
|
2576
|
+
}
|
|
2577
|
+
}
|
|
2578
|
+
}
|
|
2579
|
+
|
|
2580
|
+
void packNormal_q8_fp16(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
|
|
2581
|
+
if (rows == 4) {
|
|
2582
|
+
pack_q8_block<4>(a, lda, rows, blocks, vec);
|
|
2583
|
+
} else {
|
|
2584
|
+
pack_q8_block<8>(a, lda, rows, blocks, vec);
|
|
2585
|
+
}
|
|
1661
2586
|
}
|
|
1662
2587
|
|
|
1663
2588
|
template<int size>
|
|
1664
|
-
void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size
|
|
2589
|
+
void packNormalInt4(const TA * a, int64_t lda, int rows, int cols, int8_t * vec, std::array<int, size> & comparray) {
|
|
1665
2590
|
int64_t i, j;
|
|
1666
|
-
TA *aoffset = NULL;
|
|
1667
|
-
int8_t *vecOffset = NULL;
|
|
1668
|
-
TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
|
|
1669
|
-
TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
|
|
2591
|
+
TA * aoffset = NULL;
|
|
2592
|
+
int8_t * vecOffset = NULL;
|
|
2593
|
+
TA * aoffset1 = NULL, * aoffset2 = NULL, * aoffset3 = NULL, * aoffset4 = NULL;
|
|
2594
|
+
TA * aoffset5 = NULL, * aoffset6 = NULL, * aoffset7 = NULL, * aoffset8 = NULL;
|
|
1670
2595
|
vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
|
|
1671
2596
|
vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
|
|
1672
|
-
aoffset = const_cast<TA*>(a);
|
|
2597
|
+
aoffset = const_cast<TA *>(a);
|
|
1673
2598
|
vecOffset = vec;
|
|
1674
2599
|
j = (rows >> 3);
|
|
1675
2600
|
if (j > 0) {
|
|
@@ -1686,27 +2611,27 @@ class tinyBLAS_Q0_PPC {
|
|
|
1686
2611
|
i = (cols >> 2);
|
|
1687
2612
|
if (i > 0) {
|
|
1688
2613
|
do {
|
|
1689
|
-
c1[1] =
|
|
1690
|
-
c2[1] =
|
|
1691
|
-
c3[1] =
|
|
1692
|
-
c4[1] =
|
|
1693
|
-
c5[1] =
|
|
1694
|
-
c6[1] =
|
|
1695
|
-
c7[1] =
|
|
1696
|
-
c8[1] =
|
|
1697
|
-
|
|
1698
|
-
process_q4_elements(c1, &comparray[0]);
|
|
1699
|
-
process_q4_elements(c2, &comparray[1]);
|
|
1700
|
-
process_q4_elements(c3, &comparray[2]);
|
|
1701
|
-
process_q4_elements(c4, &comparray[3]);
|
|
1702
|
-
process_q4_elements(c5, &comparray[4]);
|
|
1703
|
-
process_q4_elements(c6, &comparray[5]);
|
|
1704
|
-
process_q4_elements(c7, &comparray[6]);
|
|
1705
|
-
process_q4_elements(c8, &comparray[7]);
|
|
2614
|
+
c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);
|
|
2615
|
+
c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);
|
|
2616
|
+
c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);
|
|
2617
|
+
c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs);
|
|
2618
|
+
c5[1] = vec_xl(0, (const vector signed char *)aoffset5->qs);
|
|
2619
|
+
c6[1] = vec_xl(0, (const vector signed char *)aoffset6->qs);
|
|
2620
|
+
c7[1] = vec_xl(0, (const vector signed char *)aoffset7->qs);
|
|
2621
|
+
c8[1] = vec_xl(0, (const vector signed char *)aoffset8->qs);
|
|
2622
|
+
|
|
2623
|
+
process_q4_elements(c1, & comparray[0]);
|
|
2624
|
+
process_q4_elements(c2, & comparray[1]);
|
|
2625
|
+
process_q4_elements(c3, & comparray[2]);
|
|
2626
|
+
process_q4_elements(c4, & comparray[3]);
|
|
2627
|
+
process_q4_elements(c5, & comparray[4]);
|
|
2628
|
+
process_q4_elements(c6, & comparray[5]);
|
|
2629
|
+
process_q4_elements(c7, & comparray[6]);
|
|
2630
|
+
process_q4_elements(c8, & comparray[7]);
|
|
1706
2631
|
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
|
|
1707
|
-
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
|
|
1708
|
-
vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
|
|
1709
|
-
vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
|
|
2632
|
+
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
|
|
2633
|
+
vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset + 128, false);
|
|
2634
|
+
vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset + 192, false);
|
|
1710
2635
|
aoffset1 += lda;
|
|
1711
2636
|
aoffset2 += lda;
|
|
1712
2637
|
aoffset3 += lda;
|
|
@@ -1732,17 +2657,17 @@ class tinyBLAS_Q0_PPC {
|
|
|
1732
2657
|
i = (cols >> 2);
|
|
1733
2658
|
if (i > 0) {
|
|
1734
2659
|
do {
|
|
1735
|
-
c1[1] =
|
|
1736
|
-
c2[1] =
|
|
1737
|
-
c3[1] =
|
|
1738
|
-
c4[1] =
|
|
1739
|
-
|
|
1740
|
-
process_q4_elements(c1, &comparray[0]);
|
|
1741
|
-
process_q4_elements(c2, &comparray[1]);
|
|
1742
|
-
process_q4_elements(c3, &comparray[2]);
|
|
1743
|
-
process_q4_elements(c4, &comparray[3]);
|
|
2660
|
+
c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);
|
|
2661
|
+
c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);
|
|
2662
|
+
c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);
|
|
2663
|
+
c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs);
|
|
2664
|
+
|
|
2665
|
+
process_q4_elements(c1, & comparray[0]);
|
|
2666
|
+
process_q4_elements(c2, & comparray[1]);
|
|
2667
|
+
process_q4_elements(c3, & comparray[2]);
|
|
2668
|
+
process_q4_elements(c4, & comparray[3]);
|
|
1744
2669
|
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
|
|
1745
|
-
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
|
|
2670
|
+
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
|
|
1746
2671
|
aoffset1 += lda;
|
|
1747
2672
|
aoffset2 += lda;
|
|
1748
2673
|
aoffset3 += lda;
|
|
@@ -1761,17 +2686,17 @@ class tinyBLAS_Q0_PPC {
|
|
|
1761
2686
|
if (i > 0) {
|
|
1762
2687
|
do {
|
|
1763
2688
|
switch(rows) {
|
|
1764
|
-
case 3: c3[1] =
|
|
1765
|
-
case 2: c2[1] =
|
|
1766
|
-
case 1: c1[1] =
|
|
2689
|
+
case 3: c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);
|
|
2690
|
+
case 2: c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);
|
|
2691
|
+
case 1: c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);
|
|
1767
2692
|
break;
|
|
1768
2693
|
}
|
|
1769
|
-
process_q4_elements(c1, &comparray[0]);
|
|
1770
|
-
process_q4_elements(c2, &comparray[1]);
|
|
1771
|
-
process_q4_elements(c3, &comparray[2]);
|
|
1772
|
-
process_q4_elements(c4, &comparray[3]);
|
|
2694
|
+
process_q4_elements(c1, & comparray[0]);
|
|
2695
|
+
process_q4_elements(c2, & comparray[1]);
|
|
2696
|
+
process_q4_elements(c3, & comparray[2]);
|
|
2697
|
+
process_q4_elements(c4, & comparray[3]);
|
|
1773
2698
|
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
|
|
1774
|
-
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
|
|
2699
|
+
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
|
|
1775
2700
|
aoffset1 += lda;
|
|
1776
2701
|
aoffset2 += lda;
|
|
1777
2702
|
aoffset3 += lda;
|
|
@@ -1781,38 +2706,39 @@ class tinyBLAS_Q0_PPC {
|
|
|
1781
2706
|
}
|
|
1782
2707
|
}
|
|
1783
2708
|
}
|
|
2709
|
+
|
|
1784
2710
|
template<typename VA, typename VB>
|
|
1785
|
-
void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
|
|
2711
|
+
void packNormal(const block_q8_0 * a, int64_t lda, int rows, int cols, VA * vec, bool flip) {
|
|
1786
2712
|
int64_t i, j;
|
|
1787
|
-
block_q8_0 *aoffset = NULL;
|
|
1788
|
-
VA *vecOffset = NULL;
|
|
1789
|
-
block_q8_0* aoffsets[8];
|
|
2713
|
+
block_q8_0 * aoffset = NULL;
|
|
2714
|
+
VA * vecOffset = NULL;
|
|
2715
|
+
block_q8_0 * aoffsets[8];
|
|
1790
2716
|
__vector_pair arr[8];
|
|
1791
2717
|
VB c[8][2] = {0};
|
|
1792
2718
|
VB c1[8] = {0}; VB c2[8] = {0};
|
|
1793
|
-
aoffset = const_cast<block_q8_0*>(a);
|
|
2719
|
+
aoffset = const_cast<block_q8_0 *>(a);
|
|
1794
2720
|
vecOffset = vec;
|
|
1795
2721
|
j = (rows >> 3);
|
|
1796
2722
|
if (j > 0) {
|
|
1797
2723
|
do {
|
|
1798
2724
|
aoffsets[0] = aoffset;
|
|
1799
2725
|
for (int it = 1; it < 8; it++)
|
|
1800
|
-
aoffsets[it] = aoffsets[it-1] + lda;
|
|
2726
|
+
aoffsets[it] = aoffsets[it - 1] + lda;
|
|
1801
2727
|
aoffset += 8 * lda;
|
|
1802
2728
|
|
|
1803
2729
|
i = (cols >> 3);
|
|
1804
2730
|
if (i > 0) {
|
|
1805
2731
|
do {
|
|
1806
2732
|
for (int it = 0; it < 8; it++) {
|
|
1807
|
-
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
|
|
1808
|
-
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
|
|
2733
|
+
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
|
|
2734
|
+
__builtin_vsx_disassemble_pair(c[it], & arr[it]);
|
|
1809
2735
|
c1[it] = c[it][0];
|
|
1810
2736
|
c2[it] = c[it][1];
|
|
1811
2737
|
}
|
|
1812
2738
|
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
|
|
1813
|
-
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
|
|
1814
|
-
vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
|
|
1815
|
-
vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
|
|
2739
|
+
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
|
|
2740
|
+
vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset + 128, flip);
|
|
2741
|
+
vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset + 192, flip);
|
|
1816
2742
|
for (int it = 0; it < 8; it++)
|
|
1817
2743
|
aoffsets[it] += lda;
|
|
1818
2744
|
vecOffset += 256;
|
|
@@ -1822,7 +2748,6 @@ class tinyBLAS_Q0_PPC {
|
|
|
1822
2748
|
j--;
|
|
1823
2749
|
} while(j > 0);
|
|
1824
2750
|
}
|
|
1825
|
-
|
|
1826
2751
|
if (rows & 4) {
|
|
1827
2752
|
aoffsets[0] = aoffset;
|
|
1828
2753
|
for (int it = 1; it < 4; it++ )
|
|
@@ -1832,13 +2757,13 @@ class tinyBLAS_Q0_PPC {
|
|
|
1832
2757
|
if (i > 0) {
|
|
1833
2758
|
do {
|
|
1834
2759
|
for (int it = 0; it < 4; it++) {
|
|
1835
|
-
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
|
|
1836
|
-
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
|
|
2760
|
+
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
|
|
2761
|
+
__builtin_vsx_disassemble_pair(c[it], & arr[it]);
|
|
1837
2762
|
c1[it] = c[it][0];
|
|
1838
2763
|
c2[it] = c[it][1];
|
|
1839
2764
|
}
|
|
1840
2765
|
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
|
|
1841
|
-
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
|
|
2766
|
+
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
|
|
1842
2767
|
for (int it = 0; it < 4; it++) {
|
|
1843
2768
|
aoffsets[it] += lda;
|
|
1844
2769
|
}
|
|
@@ -1851,24 +2776,24 @@ class tinyBLAS_Q0_PPC {
|
|
|
1851
2776
|
if (rows & 3) {
|
|
1852
2777
|
aoffsets[0] = aoffset;
|
|
1853
2778
|
for (int it = 1; it < 3; it++ )
|
|
1854
|
-
aoffsets[it] = aoffsets[it-1] + lda;
|
|
2779
|
+
aoffsets[it] = aoffsets[it - 1] + lda;
|
|
1855
2780
|
i = (cols >> 3);
|
|
1856
2781
|
if (i > 0) {
|
|
1857
2782
|
do {
|
|
1858
2783
|
switch(rows) {
|
|
1859
|
-
case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
|
|
1860
|
-
__builtin_vsx_disassemble_pair(c[2], &arr[2]);
|
|
2784
|
+
case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[2]->qs);
|
|
2785
|
+
__builtin_vsx_disassemble_pair(c[2], & arr[2]);
|
|
1861
2786
|
c1[2] = c[2][0]; c2[2] = c[2][1];
|
|
1862
|
-
case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
|
|
1863
|
-
__builtin_vsx_disassemble_pair(c[1], &arr[1]);
|
|
2787
|
+
case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[1]->qs);
|
|
2788
|
+
__builtin_vsx_disassemble_pair(c[1], & arr[1]);
|
|
1864
2789
|
c1[1] = c[1][0]; c2[1] = c[1][1];
|
|
1865
|
-
case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
|
|
1866
|
-
__builtin_vsx_disassemble_pair(c[0], &arr[0]);
|
|
2790
|
+
case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[0]->qs);
|
|
2791
|
+
__builtin_vsx_disassemble_pair(c[0], & arr[0]);
|
|
1867
2792
|
c1[0] = c[0][0]; c2[0] = c[0][1];
|
|
1868
2793
|
break;
|
|
1869
2794
|
}
|
|
1870
2795
|
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
|
|
1871
|
-
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
|
|
2796
|
+
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
|
|
1872
2797
|
for (int it = 0; it < 3; it++)
|
|
1873
2798
|
aoffsets[it] += lda;
|
|
1874
2799
|
vecOffset += 128;
|
|
@@ -1923,26 +2848,26 @@ class tinyBLAS_Q0_PPC {
|
|
|
1923
2848
|
vector float vs[8] = {0};
|
|
1924
2849
|
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
|
1925
2850
|
for (int l = 0; l < k; l++) {
|
|
1926
|
-
__builtin_mma_xxsetaccz(&acc_0);
|
|
1927
|
-
__builtin_mma_xxsetaccz(&acc_1);
|
|
2851
|
+
__builtin_mma_xxsetaccz(& acc_0);
|
|
2852
|
+
__builtin_mma_xxsetaccz(& acc_1);
|
|
1928
2853
|
if (std::is_same_v<TA, block_q4_0>) {
|
|
1929
|
-
packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
|
|
2854
|
+
packNormalInt4<4>((A + (ii * lda) + l), lda, 4, 4, (int8_t *)vec_A, comparray);
|
|
1930
2855
|
} else {
|
|
1931
|
-
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
|
|
2856
|
+
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 4, 8, (int8_t *)vec_A, false);
|
|
1932
2857
|
}
|
|
1933
|
-
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
|
|
2858
|
+
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
|
|
1934
2859
|
for(int x = 0; x < 8; x++) {
|
|
1935
|
-
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
|
1936
|
-
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
|
|
2860
|
+
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
|
2861
|
+
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x], vec_B[x+8]);
|
|
1937
2862
|
}
|
|
1938
2863
|
for (int I = 0; I<4; I++) {
|
|
1939
2864
|
for (int J = 0; J<4; J++) {
|
|
1940
|
-
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
|
1941
|
-
*((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
|
|
2865
|
+
*((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
|
2866
|
+
*((float *)& vs[I + 4] + J) = (unhalf((A +((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
|
|
1942
2867
|
}
|
|
1943
2868
|
}
|
|
1944
2869
|
if (!isAblock_q4) {
|
|
1945
|
-
auto aoffset = A+(ii*lda)+l;
|
|
2870
|
+
auto aoffset = A + (ii * lda) + l;
|
|
1946
2871
|
for (int i = 0; i < 4; i++) {
|
|
1947
2872
|
comparray[i] = 0;
|
|
1948
2873
|
int ca = 0;
|
|
@@ -1953,11 +2878,11 @@ class tinyBLAS_Q0_PPC {
|
|
|
1953
2878
|
aoffset += lda;
|
|
1954
2879
|
}
|
|
1955
2880
|
}
|
|
1956
|
-
compute
|
|
1957
|
-
compute
|
|
2881
|
+
compute(& acc_0, 0, 0, comparray, vs, fin_res);
|
|
2882
|
+
compute(& acc_1, 0, 4, comparray, vs, fin_res);
|
|
1958
2883
|
}
|
|
1959
2884
|
save_res(ii, jj, 0, fin_res);
|
|
1960
|
-
save_res(ii, jj+4, 4, fin_res);
|
|
2885
|
+
save_res(ii, jj + 4, 4, fin_res);
|
|
1961
2886
|
}
|
|
1962
2887
|
|
|
1963
2888
|
void KERNEL_8x4(int64_t ii, int64_t jj) {
|
|
@@ -1968,25 +2893,25 @@ class tinyBLAS_Q0_PPC {
|
|
|
1968
2893
|
vector float vs[8] = {0};
|
|
1969
2894
|
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
|
1970
2895
|
for (int l = 0; l < k; l++) {
|
|
1971
|
-
__builtin_mma_xxsetaccz(&acc_0);
|
|
1972
|
-
__builtin_mma_xxsetaccz(&acc_1);
|
|
2896
|
+
__builtin_mma_xxsetaccz(& acc_0);
|
|
2897
|
+
__builtin_mma_xxsetaccz(& acc_1);
|
|
1973
2898
|
if (std::is_same_v<TA, block_q4_0>) {
|
|
1974
|
-
packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
|
|
2899
|
+
packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
|
|
1975
2900
|
} else {
|
|
1976
|
-
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
|
|
2901
|
+
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
|
|
1977
2902
|
}
|
|
1978
|
-
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
|
|
2903
|
+
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 4, 8, (uint8_t *)vec_B, true);
|
|
1979
2904
|
for(int x = 0; x < 8; x++) {
|
|
1980
|
-
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
|
1981
|
-
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
|
|
2905
|
+
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
|
2906
|
+
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
|
|
1982
2907
|
}
|
|
1983
|
-
for (int I = 0; I<8; I++) {
|
|
1984
|
-
for (int J = 0; J<4; J++) {
|
|
1985
|
-
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
|
2908
|
+
for (int I = 0; I < 8; I++) {
|
|
2909
|
+
for (int J = 0; J < 4; J++) {
|
|
2910
|
+
*((float *)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
|
1986
2911
|
}
|
|
1987
2912
|
}
|
|
1988
2913
|
if (!isAblock_q4) {
|
|
1989
|
-
auto aoffset = A+(ii*lda)+l;
|
|
2914
|
+
auto aoffset = A + (ii * lda) + l;
|
|
1990
2915
|
for (int i = 0; i < 8; i++) {
|
|
1991
2916
|
comparray[i] = 0;
|
|
1992
2917
|
int ca = 0;
|
|
@@ -1997,45 +2922,46 @@ class tinyBLAS_Q0_PPC {
|
|
|
1997
2922
|
aoffset += lda;
|
|
1998
2923
|
}
|
|
1999
2924
|
}
|
|
2000
|
-
compute
|
|
2001
|
-
compute
|
|
2925
|
+
compute(& acc_0, 0, 0, comparray, vs, fin_res);
|
|
2926
|
+
compute(& acc_1, 4, 4, comparray, vs, fin_res);
|
|
2002
2927
|
}
|
|
2003
2928
|
save_res(ii, jj, 0, fin_res);
|
|
2004
|
-
save_res(ii+4, jj, 4, fin_res);
|
|
2929
|
+
save_res(ii + 4, jj, 4, fin_res);
|
|
2005
2930
|
}
|
|
2006
2931
|
|
|
2007
2932
|
void KERNEL_8x8(int64_t ii, int64_t jj) {
|
|
2008
2933
|
vec_t vec_A[16], vec_B[16] = {0};
|
|
2009
2934
|
acc_t acc_0, acc_1, acc_2, acc_3;
|
|
2935
|
+
acc_t acc_4, acc_5, acc_6, acc_7;
|
|
2010
2936
|
std::array<int, 8> comparray {};
|
|
2011
2937
|
vector float fin_res[16] = {0};
|
|
2012
2938
|
vector float vs[16] = {0};
|
|
2013
2939
|
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
|
2014
2940
|
for (int l = 0; l < k; l++) {
|
|
2015
|
-
__builtin_mma_xxsetaccz(&acc_0);
|
|
2016
|
-
__builtin_mma_xxsetaccz(&acc_1);
|
|
2017
|
-
__builtin_mma_xxsetaccz(&acc_2);
|
|
2018
|
-
__builtin_mma_xxsetaccz(&acc_3);
|
|
2941
|
+
__builtin_mma_xxsetaccz(& acc_0);
|
|
2942
|
+
__builtin_mma_xxsetaccz(& acc_1);
|
|
2943
|
+
__builtin_mma_xxsetaccz(& acc_2);
|
|
2944
|
+
__builtin_mma_xxsetaccz(& acc_3);
|
|
2019
2945
|
if (std::is_same_v<TA, block_q4_0>) {
|
|
2020
|
-
packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
|
|
2946
|
+
packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
|
|
2021
2947
|
} else {
|
|
2022
|
-
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
|
|
2948
|
+
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
|
|
2023
2949
|
}
|
|
2024
|
-
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
|
|
2950
|
+
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
|
|
2025
2951
|
for(int x = 0; x < 8; x++) {
|
|
2026
|
-
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
|
2027
|
-
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
|
|
2028
|
-
__builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
|
|
2029
|
-
__builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
|
|
2952
|
+
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
|
2953
|
+
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
|
|
2954
|
+
__builtin_mma_xvi8ger4pp(& acc_2, vec_A[x], vec_B[x + 8]);
|
|
2955
|
+
__builtin_mma_xvi8ger4pp(& acc_3, vec_A[x + 8], vec_B[x + 8]);
|
|
2030
2956
|
}
|
|
2031
|
-
for (int I = 0; I<8; I++) {
|
|
2032
|
-
for (int J = 0; J<4; J++) {
|
|
2033
|
-
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
|
2034
|
-
*((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
|
|
2957
|
+
for (int I = 0; I < 8 ; I++) {
|
|
2958
|
+
for (int J = 0; J < 4; J++) {
|
|
2959
|
+
*((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
|
2960
|
+
*((float *)& vs[I + 8] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
|
|
2035
2961
|
}
|
|
2036
2962
|
}
|
|
2037
2963
|
if (!isAblock_q4) {
|
|
2038
|
-
auto aoffset = A+(ii*lda)+l;
|
|
2964
|
+
auto aoffset = A + (ii * lda) + l;
|
|
2039
2965
|
for (int i = 0; i < 8; i++) {
|
|
2040
2966
|
comparray[i] = 0;
|
|
2041
2967
|
int ca = 0;
|
|
@@ -2046,15 +2972,96 @@ class tinyBLAS_Q0_PPC {
|
|
|
2046
2972
|
aoffset += lda;
|
|
2047
2973
|
}
|
|
2048
2974
|
}
|
|
2049
|
-
compute
|
|
2050
|
-
compute
|
|
2051
|
-
compute
|
|
2052
|
-
compute
|
|
2975
|
+
compute(& acc_0, 0, 0, comparray, vs, fin_res);
|
|
2976
|
+
compute(& acc_1, 4, 4, comparray, vs, fin_res);
|
|
2977
|
+
compute(& acc_2, 0, 8, comparray, vs, fin_res);
|
|
2978
|
+
compute(& acc_3, 4, 12, comparray, vs, fin_res);
|
|
2053
2979
|
}
|
|
2054
2980
|
save_res(ii, jj, 0, fin_res);
|
|
2055
|
-
save_res(ii+4, jj, 4, fin_res);
|
|
2056
|
-
save_res(ii, jj+4, 8, fin_res);
|
|
2057
|
-
save_res(ii+4, jj+4, 12, fin_res);
|
|
2981
|
+
save_res(ii + 4, jj, 4, fin_res);
|
|
2982
|
+
save_res(ii, jj + 4, 8, fin_res);
|
|
2983
|
+
save_res(ii + 4, jj + 4, 12, fin_res);
|
|
2984
|
+
}
|
|
2985
|
+
|
|
2986
|
+
void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t * vec_A, vec_t * vec_B) {
|
|
2987
|
+
acc_t acc[8];
|
|
2988
|
+
for (int i = 0; i < mc ; i += 16) {
|
|
2989
|
+
for (int j = 0; j < nc; j += 8) {
|
|
2990
|
+
int A0_base = (i / 16) * (2 * 32 * kc);
|
|
2991
|
+
int B0_base = (j / 8) * (32 * kc);
|
|
2992
|
+
for (int x = 0; x < 8; x++) {
|
|
2993
|
+
__builtin_mma_xxsetaccz(&acc[x]);
|
|
2994
|
+
}
|
|
2995
|
+
for (int64_t kk = 0; kk < kc; kk++) {
|
|
2996
|
+
int A0_block_idx = A0_base + kk * 32;
|
|
2997
|
+
int B0_block_idx = B0_base + kk * 32;
|
|
2998
|
+
int A1_block_idx = A0_block_idx + 32 * kc;
|
|
2999
|
+
int B1_block_idx = B0_block_idx + 32 * kc;
|
|
3000
|
+
vec_t * A0_block = & vec_A[A0_block_idx];
|
|
3001
|
+
vec_t * B0_block = & vec_B[B0_block_idx];
|
|
3002
|
+
vec_t * A1_block = & vec_A[A1_block_idx];
|
|
3003
|
+
for (int it = 0; it < 4; it++) {
|
|
3004
|
+
for (int x = 0; x < 4; x++) {
|
|
3005
|
+
__builtin_mma_xvf16ger2pp(& acc[0], A0_block[8 * it + x], B0_block[8 * it + x]);
|
|
3006
|
+
__builtin_mma_xvf16ger2pp(& acc[1], A0_block[8 * it + x], B0_block[8 * it + x + 4]);
|
|
3007
|
+
__builtin_mma_xvf16ger2pp(& acc[2], A0_block[8 * it + x + 4], B0_block[8 * it + x]);
|
|
3008
|
+
__builtin_mma_xvf16ger2pp(& acc[3], A0_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
|
|
3009
|
+
__builtin_mma_xvf16ger2pp(& acc[4], A1_block[8 * it + x], B0_block[8 * it + x]);
|
|
3010
|
+
__builtin_mma_xvf16ger2pp(& acc[5], A1_block[8 * it + x], B0_block[8 * it+ x + 4]);
|
|
3011
|
+
__builtin_mma_xvf16ger2pp(& acc[6], A1_block[8 * it + x + 4], B0_block[8 * it + x]);
|
|
3012
|
+
__builtin_mma_xvf16ger2pp(& acc[7], A1_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
|
|
3013
|
+
}
|
|
3014
|
+
}
|
|
3015
|
+
}
|
|
3016
|
+
if (l == 0) {
|
|
3017
|
+
save_acc(& acc[0], ii + i, jj + j);
|
|
3018
|
+
save_acc(& acc[1], ii + i, jj + j + 4);
|
|
3019
|
+
save_acc(& acc[2], ii + i + 4, jj + j);
|
|
3020
|
+
save_acc(& acc[3], ii + i + 4, jj + j + 4);
|
|
3021
|
+
save_acc(& acc[4], ii + i + 8, jj + j);
|
|
3022
|
+
save_acc(& acc[5], ii + i + 8, jj + j + 4);
|
|
3023
|
+
save_acc(& acc[6], ii + i + 12, jj + j);
|
|
3024
|
+
save_acc(& acc[7], ii + i + 12, jj + j + 4);
|
|
3025
|
+
} else {
|
|
3026
|
+
add_save_acc(& acc[0], ii + i, jj + j);
|
|
3027
|
+
add_save_acc(& acc[1], ii + i, jj + j + 4);
|
|
3028
|
+
add_save_acc(& acc[2], ii + i + 4, jj + j);
|
|
3029
|
+
add_save_acc(& acc[3], ii + i + 4, jj + j + 4);
|
|
3030
|
+
add_save_acc(& acc[4], ii + i + 8, jj + j);
|
|
3031
|
+
add_save_acc(& acc[5], ii + i + 8, jj + j + 4);
|
|
3032
|
+
add_save_acc(& acc[6], ii + i + 12, jj + j);
|
|
3033
|
+
add_save_acc(& acc[7], ii + i + 12, jj + j + 4);
|
|
3034
|
+
}
|
|
3035
|
+
}
|
|
3036
|
+
}
|
|
3037
|
+
}
|
|
3038
|
+
|
|
3039
|
+
void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
|
|
3040
|
+
vec_t A_pack[mc * kc * 4];
|
|
3041
|
+
vec_t B_pack[nc * kc * 4];
|
|
3042
|
+
constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
|
|
3043
|
+
int64_t ytiles = m / mc;
|
|
3044
|
+
int64_t xtiles = n / nc;
|
|
3045
|
+
int64_t tiles = xtiles * ytiles;
|
|
3046
|
+
int64_t duty = (tiles + nth - 1) / nth;
|
|
3047
|
+
int64_t start = duty * ith;
|
|
3048
|
+
int64_t end = start + duty;
|
|
3049
|
+
if (end > tiles) {
|
|
3050
|
+
end = tiles;
|
|
3051
|
+
}
|
|
3052
|
+
for (int64_t job = start; job < end; ++job) {
|
|
3053
|
+
int64_t ii = (job / xtiles) * mc;
|
|
3054
|
+
int64_t jj = (job % xtiles) * nc;
|
|
3055
|
+
for (int64_t kk = 0; kk < k; kk += kc) {
|
|
3056
|
+
if constexpr(is_Ablock_q4) {
|
|
3057
|
+
packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
|
|
3058
|
+
} else {
|
|
3059
|
+
packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
|
|
3060
|
+
}
|
|
3061
|
+
packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack);
|
|
3062
|
+
KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack);
|
|
3063
|
+
}
|
|
3064
|
+
}
|
|
2058
3065
|
}
|
|
2059
3066
|
|
|
2060
3067
|
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
|
|
@@ -2079,32 +3086,32 @@ class tinyBLAS_Q0_PPC {
|
|
|
2079
3086
|
vector float fin_res[4] = {0};
|
|
2080
3087
|
vector float vs[4] = {0};
|
|
2081
3088
|
vector float CA[4] = {0};
|
|
2082
|
-
__builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
|
|
2083
|
-
__builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
|
|
3089
|
+
__builtin_prefetch((A + (ii * lda) + 0)->qs, 0, 1); // prefetch first value
|
|
3090
|
+
__builtin_prefetch((B + (jj * ldb) + 0)->qs, 0, 1); // prefetch first value
|
|
2084
3091
|
for (int l = 0; l < k; l++) {
|
|
2085
|
-
__builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
|
|
2086
|
-
__builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
|
|
2087
|
-
__builtin_mma_xxsetaccz(&acc_0);
|
|
3092
|
+
__builtin_prefetch((A + (ii * lda) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
|
|
3093
|
+
__builtin_prefetch((B + (jj * ldb) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
|
|
3094
|
+
__builtin_mma_xxsetaccz(& acc_0);
|
|
2088
3095
|
if (isAblock_q4) {
|
|
2089
|
-
|
|
3096
|
+
packNormalInt4<4>((A + (ii * lda) + l), lda, RM, 4, (int8_t *)vec_A, comparray);
|
|
2090
3097
|
} else {
|
|
2091
|
-
|
|
3098
|
+
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, RM, 8, (int8_t *)vec_A, false);
|
|
2092
3099
|
}
|
|
2093
|
-
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
|
|
2094
|
-
for(int x = 0; x < 8; x+=4) {
|
|
2095
|
-
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
|
2096
|
-
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
|
|
2097
|
-
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
|
|
2098
|
-
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
|
|
3100
|
+
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, RN, 8, (uint8_t *)vec_B, true);
|
|
3101
|
+
for (int x = 0; x < 8; x += 4) {
|
|
3102
|
+
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
|
3103
|
+
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 1], vec_B[x + 1]);
|
|
3104
|
+
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 2], vec_B[x + 2]);
|
|
3105
|
+
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 3], vec_B[x + 3]);
|
|
2099
3106
|
}
|
|
2100
|
-
for (int I = 0; I<RM; I++) {
|
|
2101
|
-
for (int J = 0; J<RN; J++) {
|
|
2102
|
-
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
|
3107
|
+
for (int I = 0; I < RM; I++) {
|
|
3108
|
+
for (int J = 0; J < RN; J++) {
|
|
3109
|
+
*((float*)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
|
2103
3110
|
}
|
|
2104
3111
|
}
|
|
2105
|
-
__builtin_mma_disassemble_acc(vec_C, &acc_0);
|
|
3112
|
+
__builtin_mma_disassemble_acc(vec_C, & acc_0);
|
|
2106
3113
|
if (!isAblock_q4) {
|
|
2107
|
-
auto aoffset = A+(ii*lda)+l;
|
|
3114
|
+
auto aoffset = A + (ii * lda) + l;
|
|
2108
3115
|
for (int i = 0; i < RM; i++) {
|
|
2109
3116
|
comparray[i] = 0;
|
|
2110
3117
|
int ca = 0;
|
|
@@ -2127,15 +3134,15 @@ class tinyBLAS_Q0_PPC {
|
|
|
2127
3134
|
|
|
2128
3135
|
template<int RM, int RN>
|
|
2129
3136
|
inline void kernel(int64_t ii, int64_t jj) {
|
|
2130
|
-
|
|
2131
|
-
|
|
2132
|
-
|
|
2133
|
-
|
|
2134
|
-
|
|
2135
|
-
|
|
2136
|
-
|
|
2137
|
-
|
|
2138
|
-
|
|
3137
|
+
if constexpr(RM == 4 && RN == 8) {
|
|
3138
|
+
KERNEL_4x8(ii,jj);
|
|
3139
|
+
} else if constexpr(RM == 8 && RN == 4) {
|
|
3140
|
+
KERNEL_8x4(ii,jj);
|
|
3141
|
+
} else if constexpr(RM == 8 && RN == 8) {
|
|
3142
|
+
KERNEL_8x8(ii,jj);
|
|
3143
|
+
} else {
|
|
3144
|
+
assert(false && "RN/RM values not supported");
|
|
3145
|
+
}
|
|
2139
3146
|
}
|
|
2140
3147
|
|
|
2141
3148
|
template <int RM, int RN>
|
|
@@ -2154,11 +3161,11 @@ class tinyBLAS_Q0_PPC {
|
|
|
2154
3161
|
kernel<RM, RN>(ii, jj);
|
|
2155
3162
|
}
|
|
2156
3163
|
}
|
|
2157
|
-
|
|
2158
|
-
const
|
|
2159
|
-
|
|
2160
|
-
float *C;
|
|
3164
|
+
const TA * const A;
|
|
3165
|
+
const block_q8_0 * const B;
|
|
3166
|
+
float * C;
|
|
2161
3167
|
const int64_t k;
|
|
3168
|
+
int64_t kc;
|
|
2162
3169
|
const int64_t lda;
|
|
2163
3170
|
const int64_t ldb;
|
|
2164
3171
|
const int64_t ldc;
|
|
@@ -2731,6 +3738,24 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
|
|
2731
3738
|
params->ith, params->nth};
|
|
2732
3739
|
tb.matmul(m, n);
|
|
2733
3740
|
return true;
|
|
3741
|
+
#elif defined(__riscv_zvfh)
|
|
3742
|
+
#if LMUL == 1
|
|
3743
|
+
tinyBLAS_RVV<vfloat32m1_t, vfloat32m1_t, float, float, float> tb{ params,
|
|
3744
|
+
k, (const float *)A, lda,
|
|
3745
|
+
(const float *)B, ldb,
|
|
3746
|
+
(float *)C, ldc};
|
|
3747
|
+
#elif LMUL == 2
|
|
3748
|
+
tinyBLAS_RVV<vfloat32m2_t, vfloat32m2_t, float, float, float> tb{ params,
|
|
3749
|
+
k, (const float *)A, lda,
|
|
3750
|
+
(const float *)B, ldb,
|
|
3751
|
+
(float *)C, ldc};
|
|
3752
|
+
#else // LMUL = 4
|
|
3753
|
+
tinyBLAS_RVV<vfloat32m4_t, vfloat32m4_t, float, float, float> tb{ params,
|
|
3754
|
+
k, (const float *)A, lda,
|
|
3755
|
+
(const float *)B, ldb,
|
|
3756
|
+
(float *)C, ldc};
|
|
3757
|
+
#endif
|
|
3758
|
+
return tb.matmul(m, n);
|
|
2734
3759
|
#else
|
|
2735
3760
|
return false;
|
|
2736
3761
|
#endif
|
|
@@ -2762,17 +3787,38 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
|
|
2762
3787
|
return tb.matmul(m, n);
|
|
2763
3788
|
}
|
|
2764
3789
|
#elif defined(__MMA__)
|
|
2765
|
-
if (
|
|
2766
|
-
|
|
2767
|
-
|
|
2768
|
-
|
|
2769
|
-
|
|
2770
|
-
|
|
2771
|
-
|
|
2772
|
-
|
|
2773
|
-
|
|
2774
|
-
|
|
3790
|
+
if (k % 8) {
|
|
3791
|
+
return false;
|
|
3792
|
+
}
|
|
3793
|
+
|
|
3794
|
+
if (Btype == GGML_TYPE_BF16) {
|
|
3795
|
+
tinyBLAS_HP16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
|
|
3796
|
+
(const ggml_bf16_t *)A, lda,
|
|
3797
|
+
(const ggml_bf16_t *)B, ldb,
|
|
3798
|
+
(float *)C, ldc,
|
|
3799
|
+
params->ith, params->nth };
|
|
3800
|
+
|
|
3801
|
+
tb.matmul(m, n);
|
|
3802
|
+
return true;
|
|
2775
3803
|
}
|
|
3804
|
+
#elif defined(__riscv_zvfbfwma)
|
|
3805
|
+
#if LMUL == 1
|
|
3806
|
+
tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
|
|
3807
|
+
k, (const ggml_bf16_t *)A, lda,
|
|
3808
|
+
(const ggml_bf16_t *)B, ldb,
|
|
3809
|
+
(float *)C, ldc};
|
|
3810
|
+
#elif LMUL == 2
|
|
3811
|
+
tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
|
|
3812
|
+
k, (const ggml_bf16_t *)A, lda,
|
|
3813
|
+
(const ggml_bf16_t *)B, ldb,
|
|
3814
|
+
(float *)C, ldc};
|
|
3815
|
+
#else // LMUL = 4
|
|
3816
|
+
tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
|
|
3817
|
+
k, (const ggml_bf16_t *)A, lda,
|
|
3818
|
+
(const ggml_bf16_t *)B, ldb,
|
|
3819
|
+
(float *)C, ldc};
|
|
3820
|
+
#endif
|
|
3821
|
+
return tb.matmul(m, n);
|
|
2776
3822
|
#endif
|
|
2777
3823
|
return false;
|
|
2778
3824
|
}
|
|
@@ -2822,6 +3868,41 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
|
|
2822
3868
|
(float *)C, ldc};
|
|
2823
3869
|
return tb.matmul(m, n);
|
|
2824
3870
|
}
|
|
3871
|
+
#elif defined(__riscv_zvfh)
|
|
3872
|
+
if (Btype == GGML_TYPE_F16) {
|
|
3873
|
+
#if LMUL == 1
|
|
3874
|
+
tinyBLAS_RVV<vfloat32m1_t, vfloat16mf2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
|
|
3875
|
+
k, (const ggml_fp16_t *)A, lda,
|
|
3876
|
+
(const ggml_fp16_t *)B, ldb,
|
|
3877
|
+
(float *)C, ldc};
|
|
3878
|
+
#elif LMUL == 2
|
|
3879
|
+
tinyBLAS_RVV<vfloat32m2_t, vfloat16m1_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
|
|
3880
|
+
k, (const ggml_fp16_t *)A, lda,
|
|
3881
|
+
(const ggml_fp16_t *)B, ldb,
|
|
3882
|
+
(float *)C, ldc};
|
|
3883
|
+
#else // LMUL = 4
|
|
3884
|
+
tinyBLAS_RVV<vfloat32m4_t, vfloat16m2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
|
|
3885
|
+
k, (const ggml_fp16_t *)A, lda,
|
|
3886
|
+
(const ggml_fp16_t *)B, ldb,
|
|
3887
|
+
(float *)C, ldc};
|
|
3888
|
+
#endif
|
|
3889
|
+
return tb.matmul(m, n);
|
|
3890
|
+
}
|
|
3891
|
+
#elif defined(__MMA__)
|
|
3892
|
+
if (k % 8) {
|
|
3893
|
+
return false;
|
|
3894
|
+
}
|
|
3895
|
+
|
|
3896
|
+
if (Btype == GGML_TYPE_F16) {
|
|
3897
|
+
tinyBLAS_HP16_PPC<ggml_fp16_t, ggml_fp16_t, float> tb{ k,
|
|
3898
|
+
(const ggml_fp16_t *)A, lda,
|
|
3899
|
+
(const ggml_fp16_t *)B, ldb,
|
|
3900
|
+
(float *)C, ldc,
|
|
3901
|
+
params->ith, params->nth };
|
|
3902
|
+
|
|
3903
|
+
tb.matmul(m, n);
|
|
3904
|
+
return true;
|
|
3905
|
+
}
|
|
2825
3906
|
#endif
|
|
2826
3907
|
return false;
|
|
2827
3908
|
}
|