whispercpp 1.3.4 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +158 -44
- data/ext/extconf.rb +3 -2
- data/ext/ruby_whisper.c +34 -6
- data/ext/ruby_whisper.h +67 -0
- data/ext/ruby_whisper_context.c +236 -144
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +12 -13
- data/ext/ruby_whisper_params.c +47 -24
- data/ext/ruby_whisper_segment.c +84 -20
- data/ext/ruby_whisper_token.c +371 -0
- data/ext/ruby_whisper_transcribe.cpp +5 -2
- data/ext/ruby_whisper_vad_context.c +122 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +138 -0
- data/ext/ruby_whisper_vad_segments.c +105 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +129 -112
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +28 -15
- data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
- data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
- data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
- data/ext/sources/examples/talk-llama/llama-context.h +70 -23
- data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
- data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
- data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
- data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
- data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
- data/ext/sources/examples/talk-llama/llama-model.h +112 -18
- data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
- data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
- data/ext/sources/examples/talk-llama/llama.cpp +802 -21
- data/ext/sources/examples/talk-llama/llama.h +210 -39
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +704 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +90 -56
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +5 -2
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +6 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +14 -12
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +246 -21
- data/ext/sources/ggml/src/CMakeLists.txt +85 -11
- data/ext/sources/ggml/src/ggml-alloc.c +128 -50
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
- data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
- data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
- data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
- data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
- data/ext/sources/ggml/src/ggml-impl.h +129 -6
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
- data/ext/sources/ggml/src/ggml.c +590 -64
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +106 -62
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +162 -4
- data/test/test_context_params.rb +82 -0
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +81 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +100 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +27 -0
- data/whispercpp.gemspec +1 -1
- metadata +502 -37
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -15,18 +15,9 @@
|
|
|
15
15
|
|
|
16
16
|
#include <sycl/sycl.hpp>
|
|
17
17
|
#include <sycl/half_type.hpp>
|
|
18
|
-
#include <syclcompat/math.hpp>
|
|
19
|
-
#include <map>
|
|
20
|
-
|
|
21
|
-
#ifdef GGML_SYCL_USE_INTEL_ONEMKL
|
|
22
18
|
#include <oneapi/mkl.hpp>
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
namespace math = mkl;
|
|
26
|
-
}
|
|
27
|
-
#else
|
|
28
|
-
#include <oneapi/math.hpp>
|
|
29
|
-
#endif
|
|
19
|
+
|
|
20
|
+
#include <map>
|
|
30
21
|
|
|
31
22
|
#include "ggml.h"
|
|
32
23
|
|
|
@@ -92,32 +83,13 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
|
|
|
92
83
|
}
|
|
93
84
|
|
|
94
85
|
template <typename Ts> struct matrix_info_t {
|
|
95
|
-
oneapi::
|
|
86
|
+
oneapi::mkl::transpose transpose_info[2];
|
|
96
87
|
Ts value_info[2];
|
|
97
88
|
std::int64_t size_info[3];
|
|
98
89
|
std::int64_t ld_info[3];
|
|
99
90
|
std::int64_t groupsize_info;
|
|
100
91
|
};
|
|
101
92
|
|
|
102
|
-
inline auto get_onemath_backend(sycl::queue& queue)
|
|
103
|
-
#if defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
|
|
104
|
-
-> sycl::queue&
|
|
105
|
-
#endif
|
|
106
|
-
{
|
|
107
|
-
// If the backend is known at compile-time, use oneMath backend_selector to use
|
|
108
|
-
// compile-time dispatching and avoid the need to dlopen libraries. Otherwise
|
|
109
|
-
// fallback to runtime dispatching.
|
|
110
|
-
#if defined(GGML_SYCL_NVIDIA)
|
|
111
|
-
return oneapi::math::backend_selector<oneapi::math::backend::cublas>{ queue };
|
|
112
|
-
#elif defined(GGML_SYCL_AMD)
|
|
113
|
-
return oneapi::math::backend_selector<oneapi::math::backend::rocblas>{ queue };
|
|
114
|
-
#elif defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
|
|
115
|
-
return queue;
|
|
116
|
-
#else
|
|
117
|
-
static_assert(false, "Unsupported backend");
|
|
118
|
-
#endif
|
|
119
|
-
}
|
|
120
|
-
|
|
121
93
|
namespace dpct
|
|
122
94
|
{
|
|
123
95
|
typedef sycl::queue *queue_ptr;
|
|
@@ -277,6 +249,26 @@ namespace dpct
|
|
|
277
249
|
|
|
278
250
|
} // namespace detail
|
|
279
251
|
|
|
252
|
+
// Copied from the DPCT header files
|
|
253
|
+
/// dim3 is used to store 3 component dimensions.
|
|
254
|
+
class dim3 {
|
|
255
|
+
public:
|
|
256
|
+
unsigned x, y, z;
|
|
257
|
+
|
|
258
|
+
constexpr dim3(unsigned x = 1, unsigned y = 1, unsigned z = 1)
|
|
259
|
+
: x(x), y(y), z(z) {}
|
|
260
|
+
|
|
261
|
+
dim3(const sycl::id<3> &r) : dim3(r[2], r[1], r[0]) {}
|
|
262
|
+
|
|
263
|
+
operator sycl::range<3>() const { return sycl::range<3>(z, y, x); }
|
|
264
|
+
}; // namespace dim3
|
|
265
|
+
|
|
266
|
+
inline dim3 operator*(const dim3 &a, const dim3 &b) {
|
|
267
|
+
return dim3{a.x * b.x, a.y * b.y, a.z * b.z};
|
|
268
|
+
}
|
|
269
|
+
// Copied from the DPCT header files
|
|
270
|
+
|
|
271
|
+
|
|
280
272
|
/// Pitched 2D/3D memory data.
|
|
281
273
|
class pitched_data
|
|
282
274
|
{
|
|
@@ -1715,7 +1707,7 @@ namespace dpct
|
|
|
1715
1707
|
namespace detail
|
|
1716
1708
|
{
|
|
1717
1709
|
template <class Ta, class Tb, class Tc, class Ts>
|
|
1718
|
-
inline void gemm_impl(sycl::queue & q, oneapi::
|
|
1710
|
+
inline void gemm_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
|
|
1719
1711
|
int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb,
|
|
1720
1712
|
const void * beta, void * c, int ldc) {
|
|
1721
1713
|
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
|
@@ -1723,7 +1715,7 @@ namespace dpct
|
|
|
1723
1715
|
auto data_a = get_memory<const Ta>(a);
|
|
1724
1716
|
auto data_b = get_memory<const Tb>(b);
|
|
1725
1717
|
auto data_c = get_memory<Tc>(c);
|
|
1726
|
-
oneapi::
|
|
1718
|
+
oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a,
|
|
1727
1719
|
lda, data_b, ldb, beta_value, data_c, ldc);
|
|
1728
1720
|
}
|
|
1729
1721
|
|
|
@@ -1755,7 +1747,7 @@ namespace dpct
|
|
|
1755
1747
|
};
|
|
1756
1748
|
|
|
1757
1749
|
template <class Ta, class Tb, class Tc, class Ts>
|
|
1758
|
-
inline void gemm_batch_impl(sycl::queue & q, oneapi::
|
|
1750
|
+
inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
|
|
1759
1751
|
int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
|
|
1760
1752
|
int ldb, const void * beta, void ** c, int ldc, int batch_size,
|
|
1761
1753
|
matrix_info_t<float> * matrix_info) {
|
|
@@ -1774,8 +1766,8 @@ namespace dpct
|
|
|
1774
1766
|
matrix_info->ld_info[2] = ldc;
|
|
1775
1767
|
matrix_info->groupsize_info = batch_size;
|
|
1776
1768
|
|
|
1777
|
-
sycl::event e = oneapi::
|
|
1778
|
-
|
|
1769
|
+
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
|
1770
|
+
q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
|
|
1779
1771
|
matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2,
|
|
1780
1772
|
reinterpret_cast<Ts *>(matrix_info->value_info), reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
|
|
1781
1773
|
reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
|
@@ -1784,7 +1776,7 @@ namespace dpct
|
|
|
1784
1776
|
}
|
|
1785
1777
|
|
|
1786
1778
|
template <class Ta, class Tb, class Tc, class Ts>
|
|
1787
|
-
inline void gemm_batch_impl(sycl::queue & q, oneapi::
|
|
1779
|
+
inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
|
|
1788
1780
|
int m, int n, int k, const void * alpha, const void * a, int lda,
|
|
1789
1781
|
long long int stride_a, const void * b, int ldb, long long int stride_b,
|
|
1790
1782
|
const void * beta, void * c, int ldc, long long int stride_c, int batch_size) {
|
|
@@ -1793,7 +1785,7 @@ namespace dpct
|
|
|
1793
1785
|
auto data_a = get_memory<const Ta>(a);
|
|
1794
1786
|
auto data_b = get_memory<const Tb>(b);
|
|
1795
1787
|
auto data_c = get_memory<Tc>(c);
|
|
1796
|
-
oneapi::
|
|
1788
|
+
oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value,
|
|
1797
1789
|
data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
|
|
1798
1790
|
data_c, ldc, stride_c, batch_size);
|
|
1799
1791
|
}
|
|
@@ -1840,10 +1832,31 @@ namespace dpct
|
|
|
1840
1832
|
: id);
|
|
1841
1833
|
}
|
|
1842
1834
|
|
|
1835
|
+
/// Accumulator type used by dp4a: signed 32-bit unless BOTH operand types
/// are unsigned, in which case it is unsigned 32-bit.
template <typename T1, typename T2>
using dot_product_acc_t =
    std::conditional_t<!(std::is_unsigned_v<T1> && std::is_unsigned_v<T2>),
                       int32_t, uint32_t>;
|
|
1840
|
+
|
|
1841
|
+
template <typename T>
|
|
1842
|
+
sycl::vec<T, 4> extract_and_sign_or_zero_extend4(T val) {
|
|
1843
|
+
return sycl::vec<T, 1>(val)
|
|
1844
|
+
.template as<sycl::vec<
|
|
1845
|
+
std::conditional_t<std::is_signed_v<T>, int8_t, uint8_t>,
|
|
1846
|
+
4>>()
|
|
1847
|
+
.template convert<T>();
|
|
1848
|
+
}
|
|
1849
|
+
|
|
1843
1850
|
template <typename T1, typename T2, typename T3>
|
|
1844
|
-
inline auto dp4a(T1 a, T2 b, T3 c)
|
|
1845
|
-
|
|
1846
|
-
|
|
1851
|
+
inline auto dp4a(T1 a, T2 b, T3 c) {
|
|
1852
|
+
dot_product_acc_t<T1, T2> res = c;
|
|
1853
|
+
auto va = extract_and_sign_or_zero_extend4(a);
|
|
1854
|
+
auto vb = extract_and_sign_or_zero_extend4(b);
|
|
1855
|
+
res += va[0] * vb[0];
|
|
1856
|
+
res += va[1] * vb[1];
|
|
1857
|
+
res += va[2] * vb[2];
|
|
1858
|
+
res += va[3] * vb[3];
|
|
1859
|
+
return res;
|
|
1847
1860
|
}
|
|
1848
1861
|
|
|
1849
1862
|
struct sub_sat
|
|
@@ -2259,7 +2272,7 @@ namespace dpct
|
|
|
2259
2272
|
sycl::range<3>(x, y, 1), direction);
|
|
2260
2273
|
}
|
|
2261
2274
|
|
|
2262
|
-
inline void gemm(sycl::queue & q, oneapi::
|
|
2275
|
+
inline void gemm(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n,
|
|
2263
2276
|
int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b,
|
|
2264
2277
|
library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc,
|
|
2265
2278
|
library_data_t scaling_type) {
|
|
@@ -2326,7 +2339,7 @@ namespace dpct
|
|
|
2326
2339
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
2327
2340
|
library_data_t::real_float, library_data_t::real_float):
|
|
2328
2341
|
{
|
|
2329
|
-
detail::gemm_impl<oneapi::
|
|
2342
|
+
detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
|
|
2330
2343
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
|
|
2331
2344
|
break;
|
|
2332
2345
|
}
|
|
@@ -2365,7 +2378,7 @@ namespace dpct
|
|
|
2365
2378
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
2366
2379
|
library_data_t::real_bfloat16, library_data_t::real_float):
|
|
2367
2380
|
{
|
|
2368
|
-
detail::gemm_impl<oneapi::
|
|
2381
|
+
detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
|
|
2369
2382
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
|
|
2370
2383
|
break;
|
|
2371
2384
|
}
|
|
@@ -2407,7 +2420,7 @@ namespace dpct
|
|
|
2407
2420
|
/// \param [in] ldc Leading dimension of C.
|
|
2408
2421
|
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
|
|
2409
2422
|
/// \param [in] scaling_type Data type of the scaling factors.
|
|
2410
|
-
inline void gemm_batch(sycl::queue & q, oneapi::
|
|
2423
|
+
inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
|
|
2411
2424
|
int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
|
|
2412
2425
|
const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
|
|
2413
2426
|
library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
|
|
@@ -2445,7 +2458,7 @@ namespace dpct
|
|
|
2445
2458
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
2446
2459
|
library_data_t::real_bfloat16, library_data_t::real_float):
|
|
2447
2460
|
{
|
|
2448
|
-
detail::gemm_batch_impl<oneapi::
|
|
2461
|
+
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
|
|
2449
2462
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
|
2450
2463
|
break;
|
|
2451
2464
|
}
|
|
@@ -2453,7 +2466,7 @@ namespace dpct
|
|
|
2453
2466
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
2454
2467
|
library_data_t::real_float, library_data_t::real_float):
|
|
2455
2468
|
{
|
|
2456
|
-
detail::gemm_batch_impl<oneapi::
|
|
2469
|
+
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
|
|
2457
2470
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
|
2458
2471
|
break;
|
|
2459
2472
|
}
|
|
@@ -2529,7 +2542,7 @@ namespace dpct
|
|
|
2529
2542
|
/// \param [in] stride_c Stride between the different C matrices.
|
|
2530
2543
|
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
|
|
2531
2544
|
/// \param [in] scaling_type Data type of the scaling factors.
|
|
2532
|
-
inline void gemm_batch(sycl::queue & q, oneapi::
|
|
2545
|
+
inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
|
|
2533
2546
|
int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda,
|
|
2534
2547
|
long long int stride_a, const void * b, library_data_t b_type, int ldb,
|
|
2535
2548
|
long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc,
|
|
@@ -2602,7 +2615,7 @@ namespace dpct
|
|
|
2602
2615
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
2603
2616
|
library_data_t::real_bfloat16, library_data_t::real_float):
|
|
2604
2617
|
{
|
|
2605
|
-
detail::gemm_batch_impl<oneapi::
|
|
2618
|
+
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
|
|
2606
2619
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
|
|
2607
2620
|
batch_size);
|
|
2608
2621
|
break;
|
|
@@ -2611,7 +2624,7 @@ namespace dpct
|
|
|
2611
2624
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
2612
2625
|
library_data_t::real_float, library_data_t::real_float):
|
|
2613
2626
|
{
|
|
2614
|
-
detail::gemm_batch_impl<oneapi::
|
|
2627
|
+
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
|
|
2615
2628
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
|
|
2616
2629
|
batch_size);
|
|
2617
2630
|
break;
|
|
@@ -2952,6 +2965,810 @@ namespace dpct
|
|
|
2952
2965
|
atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
|
|
2953
2966
|
}
|
|
2954
2967
|
|
|
2968
|
+
/// Builds a 32-bit word from four bytes chosen out of the 8-byte pool `b:a`
/// (bytes 0-3 come from `a`, bytes 4-7 from `b`). Nibble n of `s` selects the
/// source byte for result byte n; only the low three bits of each nibble are
/// used.
inline unsigned int byte_level_permute(
    unsigned int a, unsigned int b, unsigned int s) {
    const std::uint64_t pool = (static_cast<std::uint64_t>(b) << 32) | a;
    unsigned int result = 0;
    for (unsigned int n = 0; n < 4; ++n) {
        const unsigned int src = (s >> (4 * n)) & 0x7;
        const unsigned int picked =
            static_cast<unsigned int>((pool >> (src * 8)) & 0xff);
        result |= picked << (8 * n);
    }
    return result;
}

/// byte_level_permute with an optional selector mode.
/// mode 0 uses `sel` directly as the nibble selector. Modes 1..6 ignore all
/// but the low two bits of `sel` and pick a fixed selector from the table
/// below (these rows match the PTX `prmt` f4e, b4e, rc8, ecl, ecr and rc16
/// modes). Any other mode yields 0.
inline uint32_t byte_level_permute_custom(
    uint32_t low32, uint32_t high32, uint32_t sel, int mode = 0) {
    constexpr uint16_t lookup[6][4] = {
        {0x3210, 0x4321, 0x5432, 0x6543}, // Forward 4-byte extract
        {0x5670, 0x6701, 0x7012, 0x0123}, // Backward 4-byte extract
        {0x0000, 0x1111, 0x2222, 0x3333}, // Replicate 8-bit values
        {0x3210, 0x3211, 0x3222, 0x3333}, // Edge clamp left
        {0x0000, 0x1110, 0x2210, 0x3210}, // Edge clamp right
        {0x1010, 0x3232, 0x1010, 0x3232}  // Replicate 16-bit values
    };

    if (mode == 0) {
        return byte_level_permute(low32, high32, sel);
    }
    if (mode < 1 || mode > 6) {
        return 0;
    }
    return byte_level_permute(low32, high32, lookup[mode - 1][sel & 0x3]);
}
|
|
2999
|
+
|
|
3000
|
+
template <int n_nondefault_params, int n_default_params, typename T>
class args_selector;

/// args_selector extracts the arguments of a kernel function either from an
/// array of pointers to the arguments or from a packed argument buffer.
///
/// \tparam R(Ts...) The type of the kernel.
/// \tparam n_nondefault_params The number of non-default parameters of the
///         kernel (excluding parameters such as sycl::nd_item, etc.).
/// \tparam n_default_params The number of default parameters of the kernel.
///
/// Example usage, with the following kernel:
///   void foo(sycl::float2 *x, int n, sycl::nd_item<3> item_ct1, float f=.1) {}
/// and the declaration:
///   args_selector<2, 1, decltype(foo)> selector(kernelParams, extra);
/// we have:
///   selector.get<0>() returns a reference to sycl::float2*,
///   selector.get<1>() returns a reference to int,
///   selector.get<2>() returns a reference to float.
template <int n_nondefault_params, int n_default_params, typename R,
          typename... Ts>
class args_selector<n_nondefault_params, n_default_params, R(Ts...)> {
private:
    void **params_; // array of pointers to arguments, or nullptr
    char *buffer_;  // packed argument buffer located through `extra`, or nullptr

    /// Maps argument index i to a position in Ts...: indices below
    /// n_nondefault_params map to themselves, later ones are shifted into the
    /// trailing n_default_params slots.
    template <int i> static constexpr int account_for_default_params() {
        constexpr int total = sizeof...(Ts);
        return (i < n_nondefault_params)
                   ? i
                   : total - n_default_params + (i - n_nondefault_params);
    }

public:
    /// Type of the ith argument of R(Ts...).
    template <int i>
    using arg_type = std::tuple_element_t<account_for_default_params<i>(),
                                          std::tuple<Ts...>>;
    static constexpr int params_num = sizeof...(Ts);

    /// Arguments are read from `kernel_params` when it is non-null; otherwise
    /// from the buffer that `extra` points to.
    /// \param [in] kernel_params Array of pointers to arguments, or nullptr.
    /// \param [in] extra Array containing a pointer to the argument buffer.
    args_selector(void **kernel_params, void **extra)
        : params_(kernel_params), buffer_(get_args_buffer(extra)) {}

    /// Returns a reference to the ith argument.
    template <int i> arg_type<i> &get() {
        using value_type = arg_type<i>;
        if (params_ == nullptr) {
            return *reinterpret_cast<value_type *>(buffer_ + get_offset<i>());
        }
        return *static_cast<value_type *>(params_[i]);
    }

private:
    /// Byte offset of argument i in the packed buffer: the end of argument
    /// i - 1 rounded up to alignof(arg_type<i>). Argument 0 sits at offset 0;
    /// the buffer is assumed to be suitably aligned for it.
    template <int i> static constexpr int get_offset() {
        if constexpr (i == 0) {
            return 0;
        } else {
            constexpr int prev_end =
                get_offset<i - 1>() + sizeof(arg_type<i - 1>);
            constexpr int align = alignof(arg_type<i>);
            return (prev_end + align - 1) / align * align;
        }
    }

    /// Scans the null-terminated `extra` array for the marker value 1 and
    /// returns the pointer stored right after it; nullptr if absent.
    static char *get_args_buffer(void **extra) {
        if (extra == nullptr) {
            return nullptr;
        }
        while (reinterpret_cast<std::size_t>(*extra) != 0) {
            if (reinterpret_cast<std::size_t>(*extra) == 1) {
                return static_cast<char *>(extra[1]);
            }
            ++extra;
        }
        return nullptr;
    }
}; // COPY from DPCT head file
|
|
3107
|
+
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp
|
|
3108
|
+
|
|
3109
|
+
/// Utility class for launching SYCL kernels through kernel
|
|
3110
|
+
/// function wrapper.
|
|
3111
|
+
/// For example:
|
|
3112
|
+
/// A SYCL kernel function:
|
|
3113
|
+
/// void kernel_func(int *ptr, sycl::nd_item<3> item);
|
|
3114
|
+
/// Kernel function wrapper:
|
|
3115
|
+
/// void kernel_func_wrapper(int *ptr) {
|
|
3116
|
+
/// sycl::queue queue = *dpct::kernel_launcher::_que;
|
|
3117
|
+
/// unsigned int localMemSize = dpct::kernel_launcher::_local_mem_size;
|
|
3118
|
+
/// sycl::nd_range<3> nr = dpct::kernel_launcher::_nr;
|
|
3119
|
+
/// queue.parallel_for(
|
|
3120
|
+
/// nr,
|
|
3121
|
+
/// [=](sycl::nd_item<3> item_ct1) {
|
|
3122
|
+
/// kernel_func(ptr, item_ct1);
|
|
3123
|
+
/// });
|
|
3124
|
+
/// }
|
|
3125
|
+
/// Then launch the kernel through wrapper like:
|
|
3126
|
+
/// typedef void(*fpt)(int *);
|
|
3127
|
+
/// fpt fp = kernel_func_wrapper;
|
|
3128
|
+
/// dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), 0, 0,
|
|
3129
|
+
/// device_ptr);
|
|
3130
|
+
/// If the origin function type is erased, then need to register it first:
|
|
3131
|
+
/// void *fp = (void *)wrapper_register(&kernel_func_wrapper).get();
|
|
3132
|
+
/// dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), args,
|
|
3133
|
+
/// 0, 0);
|
|
3134
|
+
class kernel_launcher {
|
|
3135
|
+
template <typename FuncT, typename ArgSelector, std::size_t... Index>
|
|
3136
|
+
static void launch_helper(FuncT &&func, ArgSelector &selector,
|
|
3137
|
+
std::index_sequence<Index...>) {
|
|
3138
|
+
func(selector.template get<Index>()...);
|
|
3139
|
+
}
|
|
3140
|
+
static void set_execution_config(dim3 group_range, dim3 local_range,
|
|
3141
|
+
unsigned int local_mem_size,
|
|
3142
|
+
queue_ptr que) {
|
|
3143
|
+
if (que) {
|
|
3144
|
+
_que = que;
|
|
3145
|
+
} else {
|
|
3146
|
+
_que = &get_default_queue();
|
|
3147
|
+
}
|
|
3148
|
+
_nr = sycl::nd_range<3>(
|
|
3149
|
+
static_cast<sycl::range<3>>(group_range * local_range),
|
|
3150
|
+
static_cast<sycl::range<3>>(local_range));
|
|
3151
|
+
_local_mem_size = local_mem_size;
|
|
3152
|
+
|
|
3153
|
+
|
|
3154
|
+
};
|
|
3155
|
+
static inline std::mutex kernel_function_ptr_map_mutex;
|
|
3156
|
+
|
|
3157
|
+
public:
|
|
3158
|
+
/// Variables for storing execution configuration.
|
|
3159
|
+
static inline thread_local sycl::queue *_que = nullptr;
|
|
3160
|
+
static inline thread_local sycl::nd_range<3> _nr = sycl::nd_range<3>();
|
|
3161
|
+
static inline thread_local unsigned int _local_mem_size = 0;
|
|
3162
|
+
/// Map for retrieving launchable functor from a raw pointer.
|
|
3163
|
+
static inline std::map<
|
|
3164
|
+
const void *,
|
|
3165
|
+
std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>>
|
|
3166
|
+
kernel_function_ptr_map = {};
|
|
3167
|
+
|
|
3168
|
+
/// Registers a kernel function pointer with a corresponding launchable
|
|
3169
|
+
/// functor.
|
|
3170
|
+
/// \param [in] func Pointer to the kernel function.
|
|
3171
|
+
/// \param [in] launcher Functor to handle kernel invocation.
|
|
3172
|
+
static void register_kernel_ptr(
|
|
3173
|
+
const void *func,
|
|
3174
|
+
std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>
|
|
3175
|
+
launcher) {
|
|
3176
|
+
std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex);
|
|
3177
|
+
kernel_function_ptr_map[func] = std::move(launcher);
|
|
3178
|
+
}
|
|
3179
|
+
/// Launches a kernel function with arguments provided directly through
|
|
3180
|
+
/// kernel function wrapper.
|
|
3181
|
+
/// \tparam FuncT Type of the kernel function wrapper.
|
|
3182
|
+
/// \tparam ArgsT Types of kernel arguments.
|
|
3183
|
+
/// \param [in] func Pointer to the kernel function wrapper.
|
|
3184
|
+
/// \param [in] group_range SYCL group range.
|
|
3185
|
+
/// \param [in] local_range SYCL local range.
|
|
3186
|
+
/// \param [in] local_mem_size The size of local memory required by the
|
|
3187
|
+
/// kernel function. \param [in] que SYCL queue used to execute kernel.
|
|
3188
|
+
/// \param [in] args Kernel arguments.
|
|
3189
|
+
template <typename FuncT, typename... ArgsT>
|
|
3190
|
+
static std::enable_if_t<std::is_invocable_v<FuncT *, ArgsT...>, void>
|
|
3191
|
+
launch(FuncT *func, dim3 group_range, dim3 local_range,
|
|
3192
|
+
unsigned int local_mem_size, queue_ptr que, ArgsT... args) {
|
|
3193
|
+
set_execution_config(group_range, local_range, local_mem_size, que);
|
|
3194
|
+
func(args...);
|
|
3195
|
+
}
|
|
3196
|
+
/// Launches a kernel function through registered kernel function
|
|
3197
|
+
/// wrapper. \param [in] func Pointer to the registered kernel function
|
|
3198
|
+
/// wrapper. \param [in] group_range SYCL group range. \param [in]
|
|
3199
|
+
/// local_range SYCL local range. \param [in] args Array of pointers to
|
|
3200
|
+
/// kernel arguments. \param [in] local_mem_size The size of local
|
|
3201
|
+
/// memory required by the kernel function. \param [in] que SYCL queue
|
|
3202
|
+
/// used to execute kernel.
|
|
3203
|
+
static void launch(const void *func, dim3 group_range, dim3 local_range,
|
|
3204
|
+
void **args, unsigned int local_mem_size,
|
|
3205
|
+
queue_ptr que) {
|
|
3206
|
+
std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex);
|
|
3207
|
+
auto Iter = kernel_function_ptr_map.find(func);
|
|
3208
|
+
if (Iter == kernel_function_ptr_map.end()) {
|
|
3209
|
+
throw std::runtime_error("dpct::launch() : no registered "
|
|
3210
|
+
"kernel function wrapper found.");
|
|
3211
|
+
}
|
|
3212
|
+
(Iter->second)(group_range, local_range, args, local_mem_size, que);
|
|
3213
|
+
}
|
|
3214
|
+
/// Launches a kernel function with packed arguments through kernel
|
|
3215
|
+
/// function wrapper.
|
|
3216
|
+
/// \tparam FuncT Type of the kernel function wrapper.
|
|
3217
|
+
/// \param [in] func Pointer to the kernel function wrapper.
|
|
3218
|
+
/// \param [in] group_range SYCL group range.
|
|
3219
|
+
/// \param [in] local_range SYCL local range.
|
|
3220
|
+
/// \param [in] args Array of pointers to kernel arguments.
|
|
3221
|
+
/// \param [in] local_mem_size The size of local memory required by the
|
|
3222
|
+
/// kernel function. \param [in] que SYCL queue used to execute kernel.
|
|
3223
|
+
template <typename FuncT>
|
|
3224
|
+
static std::enable_if_t<std::is_function_v<FuncT>, void>
|
|
3225
|
+
launch(FuncT *func, dim3 group_range, dim3 local_range, void **args,
|
|
3226
|
+
unsigned int local_mem_size, queue_ptr que) {
|
|
3227
|
+
constexpr size_t p_num = args_selector<0, 0, FuncT>::params_num;
|
|
3228
|
+
set_execution_config(group_range, local_range, local_mem_size, que);
|
|
3229
|
+
args_selector<p_num, p_num, FuncT> selector(args, nullptr);
|
|
3230
|
+
launch_helper(func, selector, std::make_index_sequence<p_num>{});
|
|
3231
|
+
}
|
|
3232
|
+
}; // COPY from DPCT head file
|
|
3233
|
+
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/kernel.hpp
|
|
3234
|
+
|
|
3235
|
+
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp
|
|
3236
|
+
template <typename T>
|
|
3237
|
+
T select_from_sub_group(
|
|
3238
|
+
sycl::sub_group g,
|
|
3239
|
+
T x,
|
|
3240
|
+
int remote_local_id,
|
|
3241
|
+
int logical_sub_group_size = 32) {
|
|
3242
|
+
unsigned int start_index = g.get_local_linear_id() /
|
|
3243
|
+
logical_sub_group_size *
|
|
3244
|
+
logical_sub_group_size;
|
|
3245
|
+
return sycl::select_from_group(
|
|
3246
|
+
g, x, start_index + remote_local_id % logical_sub_group_size);
|
|
3247
|
+
}
|
|
3248
|
+
|
|
3249
|
+
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp
|
|
3250
|
+
template <typename T>
|
|
3251
|
+
void ldmatrix(uintptr_t addr, T* m, bool trans = false, unsigned mat = 0) {
|
|
3252
|
+
auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
|
|
3253
|
+
int lane = sg.get_local_linear_id();
|
|
3254
|
+
|
|
3255
|
+
int lane_group8_row = lane / 8;
|
|
3256
|
+
int lane_group8_col = lane % 8;
|
|
3257
|
+
|
|
3258
|
+
if (!trans) {
|
|
3259
|
+
// calculate the source lane
|
|
3260
|
+
int src_lane = 2 * lane_group8_row;
|
|
3261
|
+
if (lane_group8_col >= 4)
|
|
3262
|
+
src_lane += 1;
|
|
3263
|
+
|
|
3264
|
+
// Broadcast the address from the source lane
|
|
3265
|
+
auto recv_addr_uintp =
|
|
3266
|
+
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
|
|
3267
|
+
|
|
3268
|
+
// Cast the received address from uintptr_t to the type of 'm'
|
|
3269
|
+
auto recv_addr = reinterpret_cast<T*>(recv_addr_uintp);
|
|
3270
|
+
|
|
3271
|
+
// Non-transposed load
|
|
3272
|
+
*m = recv_addr[lane_group8_col % 4];
|
|
3273
|
+
} else {
|
|
3274
|
+
// calculate the source lane
|
|
3275
|
+
int src_lane = (lane % 4) * 2;
|
|
3276
|
+
|
|
3277
|
+
// Broadcast the address from the source lane
|
|
3278
|
+
auto recv_addr_uintp_1 =
|
|
3279
|
+
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
|
|
3280
|
+
auto recv_addr_uintp_2 =
|
|
3281
|
+
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane + 1);
|
|
3282
|
+
|
|
3283
|
+
// Cast the received address from uintptr_t to 'half *'
|
|
3284
|
+
auto recv_addr_1 = reinterpret_cast<sycl::half*>(recv_addr_uintp_1);
|
|
3285
|
+
auto recv_addr_2 = reinterpret_cast<sycl::half*>(recv_addr_uintp_2);
|
|
3286
|
+
|
|
3287
|
+
// Transposed load
|
|
3288
|
+
int index = lane / 4;
|
|
3289
|
+
sycl::half val0 = recv_addr_1[index];
|
|
3290
|
+
sycl::half val1 = recv_addr_2[index];
|
|
3291
|
+
|
|
3292
|
+
// Combine the two 16-bits into one 32-bit value
|
|
3293
|
+
sycl::half2 val = sycl::half2(val0, val1);
|
|
3294
|
+
*m = *reinterpret_cast<T*>(&val);
|
|
3295
|
+
}
|
|
3296
|
+
}
|
|
3297
|
+
|
|
3298
|
+
/// Loads two matrix fragments (matrix indices 0 and 1) into m1 and m2 via the
/// single-fragment ldmatrix.
template <typename T>
void ldmatrix(uintptr_t addr, T* m1, T* m2, bool trans = false) {
    T* const targets[] = {m1, m2};
    for (unsigned mat = 0; mat < 2; ++mat) {
        ldmatrix(addr, targets[mat], trans, mat);
    }
}
|
|
3305
|
+
|
|
3306
|
+
/// Loads four matrix fragments (matrix indices 0 through 3) into m1..m4 via
/// the single-fragment ldmatrix, in index order.
template <typename T>
void ldmatrix(
    uintptr_t addr, T* m1, T* m2, T* m3, T* m4, bool trans = false) {
    T* const targets[] = {m1, m2, m3, m4};
    for (unsigned mat = 0; mat < 4; ++mat) {
        ldmatrix(addr, targets[mat], trans, mat);
    }
}
|
|
3318
|
+
|
|
3319
|
+
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp
|
|
3320
|
+
|
|
3321
|
+
/// Pack type used when mma() exchanges input-matrix fragments between work
/// items. The primary template packs every element type into a 32-bit word.
/// \tparam [in] T The type of the input matrix fragments.
/// NOTE(review): the upstream comment mentions specializations for f16, bf16
/// and s8; none appear next to this definition — verify before relying on it.
template <typename T>
struct MMAType {
    typedef uint32_t PackType;
};
|
|
3331
|
+
|
|
3332
|
+
/// Each work item of a sub-group (limited to size 32) calling this function
|
|
3333
|
+
/// calculates a subset fragment for the output matrix D using MAD operation
|
|
3334
|
+
/// on A, B & C matrix fragments (D = A * B + C). Current supported shapes &
|
|
3335
|
+
/// types:
|
|
3336
|
+
/// - m8n8k4 (f32.f16.f16.f32)
|
|
3337
|
+
/// - m8n8k16 (s32.s8.s8.s32)
|
|
3338
|
+
/// - m16n8k8 (f32.f16.f16.f32 & f32.bf16.bf16.f32)
|
|
3339
|
+
/// - m16n8k16 (f32.f16.f16.f32 & s32.s8.s8.s32)
|
|
3340
|
+
/// - m16n8k32 (s32.s8.s8.s32)
|
|
3341
|
+
/// Here, m, n & k define the shapes of A, B & C matrices respectively
|
|
3342
|
+
/// (A = [m x k], B = [k x n], C = [m x n]).
|
|
3343
|
+
/// \tparam [in] M The rows of A, C & D matrices
|
|
3344
|
+
/// \tparam [in] N The columns of B, C, D matrices
|
|
3345
|
+
/// \tparam [in] K The columns & rows of A & B matrices respectively
|
|
3346
|
+
/// \tparam [in] ABType The type of the input matrix (A & B) fragment
|
|
3347
|
+
/// \tparam [in] CDType The type of the output matrix (C & D) fragment
|
|
3348
|
+
/// \param [out] d_mat_frag The fragment of the output matrix D to store the
|
|
3349
|
+
/// result of A * B + C
|
|
3350
|
+
/// \param [in] a_mat_frag The fragment of the input matrix A to be
|
|
3351
|
+
/// multiplied with B matrix fragment \param [in] b_mat_frag The fragment of
|
|
3352
|
+
/// the input matrix B to be multiplied with A matrix fragment \param [in]
|
|
3353
|
+
/// c_mat_frag The fragment of the input matrix C to be added with the
|
|
3354
|
+
/// result of A * B fragments
|
|
3355
|
+
template <int M, int N, int K, typename ABType, typename CDType>
|
|
3356
|
+
void mma(
|
|
3357
|
+
volatile void** d_mat_frag,
|
|
3358
|
+
void* a_mat_frag,
|
|
3359
|
+
void* b_mat_frag,
|
|
3360
|
+
void* c_mat_frag) {
|
|
3361
|
+
auto d = reinterpret_cast<volatile CDType**>(d_mat_frag);
|
|
3362
|
+
auto a =
|
|
3363
|
+
reinterpret_cast<typename MMAType<ABType>::PackType*>(a_mat_frag);
|
|
3364
|
+
auto b =
|
|
3365
|
+
reinterpret_cast<typename MMAType<ABType>::PackType*>(b_mat_frag);
|
|
3366
|
+
auto c = reinterpret_cast<CDType*>(c_mat_frag);
|
|
3367
|
+
|
|
3368
|
+
auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
|
|
3369
|
+
int lane = sg.get_local_linear_id();
|
|
3370
|
+
|
|
3371
|
+
static_assert(
|
|
3372
|
+
(M == 8 && N == 8 && K == 4) || (M == 8 && N == 8 && K == 16) ||
|
|
3373
|
+
(M == 16 && N == 8 && K == 8) || (M == 16 && N == 8 && K == 16) ||
|
|
3374
|
+
(M == 16 && N == 8 && K == 32),
|
|
3375
|
+
"Unsupported MMA shape!");
|
|
3376
|
+
|
|
3377
|
+
short row_load_offset = 4 * (lane >> 2);
|
|
3378
|
+
short col_load_offset = 8 * (lane % 4);
|
|
3379
|
+
|
|
3380
|
+
if constexpr (M == 8 && N == 8 && K == 4) {
|
|
3381
|
+
if constexpr (std::is_floating_point_v<CDType>) {
|
|
3382
|
+
col_load_offset = row_load_offset % 16;
|
|
3383
|
+
|
|
3384
|
+
// Init D matrix with fragments of C matrix
|
|
3385
|
+
*d[0] = c[0];
|
|
3386
|
+
*d[1] = c[1];
|
|
3387
|
+
*d[2] = c[2];
|
|
3388
|
+
*d[3] = c[3];
|
|
3389
|
+
*d[4] = c[4];
|
|
3390
|
+
*d[5] = c[5];
|
|
3391
|
+
*d[6] = c[6];
|
|
3392
|
+
*d[7] = c[7];
|
|
3393
|
+
|
|
3394
|
+
// Calculate the row and col offset indices to iterate through the row
|
|
3395
|
+
// & col fragments of A & B matrices
|
|
3396
|
+
int r_ind = (lane % 2) ? 1 : 0;
|
|
3397
|
+
int c_ind = ((lane % 4) / 2) ? 2 : 0;
|
|
3398
|
+
|
|
3399
|
+
// Each sub-group is responsible for computing a fragment size of 8*8
|
|
3400
|
+
// elements of matrix D for each of 4 MMA computations.
|
|
3401
|
+
// Each work item computes 8 elements of matrix D by gathering
|
|
3402
|
+
// their corresponding col & row matrix fragments of length k (4)
|
|
3403
|
+
// from A & B matrices respectively using below mapping logic:
|
|
3404
|
+
// row0 = (i % 4) if (lane < 16) else (i % 4) + 4
|
|
3405
|
+
// col0 = (lane % 4)
|
|
3406
|
+
// As each row & col fragment of A & B matrices is distributed across
|
|
3407
|
+
// 4 work items, each iteration of below loop loads a partial fragment
|
|
3408
|
+
// of matrix A (row) and matrix B (col) using the row & col offsets.
|
|
3409
|
+
typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
|
|
3410
|
+
|
|
3411
|
+
for (int i = 0; i < 4; i++) {
|
|
3412
|
+
// Load partial fragment from col0 of matrix A ({a0, a1})
|
|
3413
|
+
recv_a[0] =
|
|
3414
|
+
dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
|
|
3415
|
+
// Load partial fragment from col0 of matrix A ({a2, a3})
|
|
3416
|
+
recv_a[1] =
|
|
3417
|
+
dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
|
|
3418
|
+
|
|
3419
|
+
// Load partial fragment from row0 of matrix B ({b0, b1})
|
|
3420
|
+
recv_b[0] =
|
|
3421
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
|
|
3422
|
+
// Load partial fragment from row0 of matrix B ({b2, b3})
|
|
3423
|
+
recv_b[1] =
|
|
3424
|
+
dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
|
|
3425
|
+
|
|
3426
|
+
auto ra = reinterpret_cast<ABType*>(recv_a);
|
|
3427
|
+
auto rb = reinterpret_cast<ABType*>(recv_b);
|
|
3428
|
+
|
|
3429
|
+
// Each work item calculates a partial product of A & B matrix
|
|
3430
|
+
// fragments and adds it to the corresponding D matrix fragment (for
|
|
3431
|
+
// even work item indices) d0 += col0{ a0 } * row0{ b0 } d1 += col0{
|
|
3432
|
+
// a0 } * row0{ b1 } d2 += col1{ a2 } * row0{ b0 } d3 += col1{ a2 }
|
|
3433
|
+
// * row0{ b1 } (for odd work item indices) d0 += col0{ a1 } * row0{
|
|
3434
|
+
// b2 } d1 += col0{ a1 } * row0{ b3 } d2 += col1{ a3 } * row0{ b2 }
|
|
3435
|
+
// d3 += col1{ a3 } * row0{ b3 }
|
|
3436
|
+
*d[0] +=
|
|
3437
|
+
static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]);
|
|
3438
|
+
*d[1] += static_cast<float>(ra[r_ind]) *
|
|
3439
|
+
static_cast<float>(rb[c_ind + 1]);
|
|
3440
|
+
*d[2] += static_cast<float>(ra[r_ind + 2]) *
|
|
3441
|
+
static_cast<float>(rb[c_ind]);
|
|
3442
|
+
*d[3] += static_cast<float>(ra[r_ind + 2]) *
|
|
3443
|
+
static_cast<float>(rb[c_ind + 1]);
|
|
3444
|
+
|
|
3445
|
+
// Load partial fragment from row1 of matrix B ({b0, b1})
|
|
3446
|
+
recv_b[0] =
|
|
3447
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 16);
|
|
3448
|
+
// Load partial fragment from row1 of matrix B ({b2, b3})
|
|
3449
|
+
recv_b[1] =
|
|
3450
|
+
dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 16);
|
|
3451
|
+
|
|
3452
|
+
// (for even work item indices)
|
|
3453
|
+
// d0 += col0{ a0 } * row1{ b0 }
|
|
3454
|
+
// d1 += col0{ a0 } * row1{ b1 }
|
|
3455
|
+
// d2 += col1{ a2 } * row1{ b0 }
|
|
3456
|
+
// d3 += col1{ a2 } * row1{ b1 }
|
|
3457
|
+
// (for odd work item indices)
|
|
3458
|
+
// d0 += col0{ a1 } * row1{ b2 }
|
|
3459
|
+
// d1 += col0{ a1 } * row1{ b3 }
|
|
3460
|
+
// d2 += col1{ a3 } * row1{ b2 }
|
|
3461
|
+
// d3 += col1{ a3 } * row1{ b3 }
|
|
3462
|
+
*d[4] +=
|
|
3463
|
+
static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]);
|
|
3464
|
+
*d[5] += static_cast<float>(ra[r_ind]) *
|
|
3465
|
+
static_cast<float>(rb[c_ind + 1]);
|
|
3466
|
+
*d[6] += static_cast<float>(ra[r_ind + 2]) *
|
|
3467
|
+
static_cast<float>(rb[c_ind]);
|
|
3468
|
+
*d[7] += static_cast<float>(ra[r_ind + 2]) *
|
|
3469
|
+
static_cast<float>(rb[c_ind + 1]);
|
|
3470
|
+
}
|
|
3471
|
+
}
|
|
3472
|
+
} else if constexpr (M == 8 && N == 8 && K == 16) {
|
|
3473
|
+
if constexpr (std::is_integral_v<ABType>) {
|
|
3474
|
+
// Init D matrix with fragments of C matrix
|
|
3475
|
+
*d[0] = c[0];
|
|
3476
|
+
*d[1] = c[1];
|
|
3477
|
+
|
|
3478
|
+
// Each sub-group is responsible for computing a fragment size of 16*8
|
|
3479
|
+
// elements of matrix D.
|
|
3480
|
+
// Each work item computes 2 elements of matrix D by gathering
|
|
3481
|
+
// their corresponding row & col matrix fragments of length k (16)
|
|
3482
|
+
// from A & B matrices respectively using below mapping logic:
|
|
3483
|
+
// row0 = ((lane % 4) * 4) + i
|
|
3484
|
+
// col0 = (lane >> 2)
|
|
3485
|
+
// As each row & col fragment of A & B matrices is distributed across
|
|
3486
|
+
// 4 work items, each iteration of below loop loads a partial fragment
|
|
3487
|
+
// of matrix A (row) and matrix B (col) using the row & col offsets.
|
|
3488
|
+
for (int i = 0; i < 4; i++) {
|
|
3489
|
+
typename MMAType<ABType>::PackType recv_a, recv_b[2];
|
|
3490
|
+
|
|
3491
|
+
// Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
|
|
3492
|
+
recv_a = dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
|
|
3493
|
+
// Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
|
|
3494
|
+
recv_b[0] =
|
|
3495
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
|
|
3496
|
+
// Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})
|
|
3497
|
+
recv_b[1] =
|
|
3498
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
|
|
3499
|
+
|
|
3500
|
+
auto a = reinterpret_cast<ABType*>(&recv_a);
|
|
3501
|
+
auto b = reinterpret_cast<ABType*>(recv_b);
|
|
3502
|
+
|
|
3503
|
+
// Each work item calculates a partial product of A & B matrix
|
|
3504
|
+
// fragments and adds it to the corresponding D matrix fragment d0
|
|
3505
|
+
// += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
|
|
3506
|
+
// a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row0{ a0, a1, a2,
|
|
3507
|
+
// a3 } * col0{ b0, b1, b2, b3 } d3 += row0{ a0, a1, a2, a3 } *
|
|
3508
|
+
// col1{ b0, b1, b2, b3 }
|
|
3509
|
+
for (int j = 0; j < 4; j++) {
|
|
3510
|
+
*d[0] += a[j] * b[j];
|
|
3511
|
+
*d[1] += a[j] * b[j + 4];
|
|
3512
|
+
}
|
|
3513
|
+
}
|
|
3514
|
+
}
|
|
3515
|
+
} else if constexpr (M == 16 && N == 8 && K == 8) {
|
|
3516
|
+
if constexpr (std::is_floating_point_v<CDType>) {
|
|
3517
|
+
// Init D matrix fragment with C matrix fragment
|
|
3518
|
+
*d[0] = c[0];
|
|
3519
|
+
*d[1] = c[1];
|
|
3520
|
+
*d[2] = c[2];
|
|
3521
|
+
*d[3] = c[3];
|
|
3522
|
+
|
|
3523
|
+
// Each sub-group is responsible for computing a fragment size of 16*8
|
|
3524
|
+
// elements of matrix D.
|
|
3525
|
+
// Each work item computes 4 elements of matrix D by gathering
|
|
3526
|
+
// their corresponding row & col matrix fragments of length k (8)
|
|
3527
|
+
// from A & B matrices respectively using below mapping logic:
|
|
3528
|
+
// row0 = (lane >> 2) & row1 = (lane >> 2) + 8
|
|
3529
|
+
// col0 = (lane % 4) * 2 + (i & 0x1)
|
|
3530
|
+
// As each row & col fragment of A & B matrices is distributed across
|
|
3531
|
+
// 4 work items, each iteration of below loop loads a partial fragment
|
|
3532
|
+
// of matrix A (row) and matrix B (col) using the row & col offsets.
|
|
3533
|
+
for (int i = 0; i < 4; i++) {
|
|
3534
|
+
typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
|
|
3535
|
+
|
|
3536
|
+
// Load partial fragment from row0 of matrix A ({a0, a1})
|
|
3537
|
+
recv_a[0] =
|
|
3538
|
+
dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
|
|
3539
|
+
// Load partial fragment from row1 of matrix A ({a2, a3})
|
|
3540
|
+
recv_a[1] =
|
|
3541
|
+
dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
|
|
3542
|
+
// Load partial fragment from col0 of matrix B ({b0, b1})
|
|
3543
|
+
recv_b[0] =
|
|
3544
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
|
|
3545
|
+
// Load partial fragment from col1 of matrix B ({b0, b1})
|
|
3546
|
+
recv_b[1] =
|
|
3547
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
|
|
3548
|
+
|
|
3549
|
+
auto ra = reinterpret_cast<ABType*>(recv_a);
|
|
3550
|
+
auto rb = reinterpret_cast<ABType*>(recv_b);
|
|
3551
|
+
|
|
3552
|
+
// Each work item calculates a partial product of A & B matrix
|
|
3553
|
+
// fragments and adds it to the corresponding D matrix fragment d0
|
|
3554
|
+
// += row0{ a0, a1 } * col0{ b0, b1 } d1 += row0{ a0, a1 } * col1{
|
|
3555
|
+
// b0, b1 } d2 += row1{ a2, a3 } * col0{ b0, b1 } d3 += row1{ a2, a3
|
|
3556
|
+
// } * col1{ b0, b1 }
|
|
3557
|
+
for (int j = 0; j < 2; j++) {
|
|
3558
|
+
*d[0] += static_cast<float>(ra[j]) * static_cast<float>(rb[j]);
|
|
3559
|
+
*d[1] +=
|
|
3560
|
+
static_cast<float>(ra[j]) * static_cast<float>(rb[j + 2]);
|
|
3561
|
+
*d[2] +=
|
|
3562
|
+
static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j]);
|
|
3563
|
+
*d[3] +=
|
|
3564
|
+
static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j + 2]);
|
|
3565
|
+
}
|
|
3566
|
+
}
|
|
3567
|
+
}
|
|
3568
|
+
} else if constexpr (M == 16 && N == 8 && K == 16) {
|
|
3569
|
+
if constexpr (std::is_floating_point_v<CDType>) {
|
|
3570
|
+
// Init D matrix fragment with C matrix fragment
|
|
3571
|
+
*d[0] = c[0];
|
|
3572
|
+
*d[1] = c[1];
|
|
3573
|
+
*d[2] = c[2];
|
|
3574
|
+
*d[3] = c[3];
|
|
3575
|
+
|
|
3576
|
+
// Each sub-group is responsible for computing a fragment size of 16*8
|
|
3577
|
+
// elements of matrix D.
|
|
3578
|
+
// Each work item computes 4 elements of matrix D by gathering
|
|
3579
|
+
// their corresponding row & col matrix fragments of length k (8)
|
|
3580
|
+
// from A & B matrices respectively using below mapping logic:
|
|
3581
|
+
// row0 = (lane >> 2) & row1 = (lane >> 2) + 8
|
|
3582
|
+
// col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
|
|
3583
|
+
// As each row & col fragment of A & B matrices is distributed across
|
|
3584
|
+
// 4 work items, each iteration of below loop loads a partial fragment
|
|
3585
|
+
// of matrix A (row) and matrix B (col) using the row & col offsets.
|
|
3586
|
+
for (int i = 0; i < 4; i++) {
|
|
3587
|
+
typename MMAType<ABType>::PackType recv_a[4], recv_b[4];
|
|
3588
|
+
|
|
3589
|
+
// Load partial fragment from row0 of matrix A ({a0, a1})
|
|
3590
|
+
recv_a[0] =
|
|
3591
|
+
dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
|
|
3592
|
+
// Load partial fragment from row0 of matrix A ({a2, a3})
|
|
3593
|
+
recv_a[1] =
|
|
3594
|
+
dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
|
|
3595
|
+
// Load partial fragment from row1 of matrix A ({a0, a1})
|
|
3596
|
+
recv_a[2] =
|
|
3597
|
+
dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
|
|
3598
|
+
// Load partial fragment from row1 of matrix A ({a2, a3})
|
|
3599
|
+
recv_a[3] =
|
|
3600
|
+
dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
|
|
3601
|
+
|
|
3602
|
+
// Load partial fragment from col0 of matrix B ({b0, b1})
|
|
3603
|
+
recv_b[0] =
|
|
3604
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
|
|
3605
|
+
// Load partial fragment from col0 of matrix B ({b2, b3})
|
|
3606
|
+
recv_b[1] =
|
|
3607
|
+
dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
|
|
3608
|
+
// Load partial fragment from col1 of matrix B ({b0, b1})
|
|
3609
|
+
recv_b[2] =
|
|
3610
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + 4 + i);
|
|
3611
|
+
// Load partial fragment from col1 of matrix B ({b2, b3})
|
|
3612
|
+
recv_b[3] =
|
|
3613
|
+
dpct::select_from_sub_group(sg, b[1], col_load_offset + 4 + i);
|
|
3614
|
+
|
|
3615
|
+
auto ra = reinterpret_cast<ABType*>(recv_a);
|
|
3616
|
+
auto rb = reinterpret_cast<ABType*>(recv_b);
|
|
3617
|
+
|
|
3618
|
+
// Each work item calculates a partial product of A & B matrix
|
|
3619
|
+
// fragments and adds it to the corresponding D matrix fragment d0
|
|
3620
|
+
// += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
|
|
3621
|
+
// a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a0, a1, a2,
|
|
3622
|
+
// a3 } * col0{ b0, b1, b2, b3 } d3 += row1{ a0, a1, a2, a3 } *
|
|
3623
|
+
// col1{ b0, b1, b2, b3 }
|
|
3624
|
+
for (int j = 0; j < 4; j++) {
|
|
3625
|
+
*d[0] += static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]);
|
|
3626
|
+
*d[1] +=
|
|
3627
|
+
static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j + 4]);
|
|
3628
|
+
*d[2] +=
|
|
3629
|
+
static_cast<CDType>(ra[j + 4]) * static_cast<CDType>(rb[j]);
|
|
3630
|
+
*d[3] += static_cast<CDType>(ra[j + 4]) *
|
|
3631
|
+
static_cast<CDType>(rb[j + 4]);
|
|
3632
|
+
}
|
|
3633
|
+
}
|
|
3634
|
+
} else if constexpr (std::is_integral_v<ABType>) {
|
|
3635
|
+
// Init D matrix with fragments of C matrix
|
|
3636
|
+
*d[0] = c[0];
|
|
3637
|
+
*d[1] = c[1];
|
|
3638
|
+
*d[2] = c[2];
|
|
3639
|
+
*d[3] = c[3];
|
|
3640
|
+
|
|
3641
|
+
// Each sub-group is responsible for computing a fragment size of 16*8
|
|
3642
|
+
// elements of matrix D.
|
|
3643
|
+
// Each work item computes 4 elements of matrix D by gathering
|
|
3644
|
+
// their corresponding row & col matrix fragments of length k (8)
|
|
3645
|
+
// from A & B matrices respectively using below mapping logic:
|
|
3646
|
+
// row0 = (lane >> 2) & row1 = (lane >> 2) + 8
|
|
3647
|
+
// col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
|
|
3648
|
+
// As each row & col fragment of A & B matrices is distributed across
|
|
3649
|
+
// 4 work items, each iteration of below loop loads a partial fragment
|
|
3650
|
+
// of matrix A (row) and matrix B (col) using the row & col offsets.
|
|
3651
|
+
for (int i = 0; i < 4; i++) {
|
|
3652
|
+
typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
|
|
3653
|
+
|
|
3654
|
+
// Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
|
|
3655
|
+
recv_a[0] =
|
|
3656
|
+
dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
|
|
3657
|
+
// Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})
|
|
3658
|
+
recv_a[1] =
|
|
3659
|
+
dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
|
|
3660
|
+
// Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
|
|
3661
|
+
recv_b[0] =
|
|
3662
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
|
|
3663
|
+
// Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})
|
|
3664
|
+
recv_b[1] =
|
|
3665
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
|
|
3666
|
+
|
|
3667
|
+
auto ra = reinterpret_cast<ABType*>(recv_a);
|
|
3668
|
+
auto rb = reinterpret_cast<ABType*>(recv_b);
|
|
3669
|
+
|
|
3670
|
+
// Each work item calculates a partial product of A & B matrix
|
|
3671
|
+
// fragments and adds it to the corresponding D matrix fragment d0
|
|
3672
|
+
// += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
|
|
3673
|
+
// a0, a1, a2, a3 } * col1{ b4, b5, b6, b7 } d2 += row1{ a4, a5, a6,
|
|
3674
|
+
// a7 } * col0{ b0, b1, b2, b3 } d3 += row1{ a4, a5, a6, a7 } *
|
|
3675
|
+
// col1{ b4, b5, b6, b7 }
|
|
3676
|
+
for (int i = 0; i < 4; i++) {
|
|
3677
|
+
*d[0] += ra[i] * rb[i];
|
|
3678
|
+
*d[1] += ra[i] * rb[i + 4];
|
|
3679
|
+
*d[2] += ra[i + 4] * rb[i];
|
|
3680
|
+
*d[3] += ra[i + 4] * rb[i + 4];
|
|
3681
|
+
}
|
|
3682
|
+
}
|
|
3683
|
+
}
|
|
3684
|
+
} else if constexpr (M == 16 && N == 8 && K == 32) {
|
|
3685
|
+
if constexpr (std::is_integral_v<ABType>) {
|
|
3686
|
+
// Init D matrix with fragments of C matrix
|
|
3687
|
+
*d[0] = c[0];
|
|
3688
|
+
*d[1] = c[1];
|
|
3689
|
+
*d[2] = c[2];
|
|
3690
|
+
*d[3] = c[3];
|
|
3691
|
+
|
|
3692
|
+
// Each sub-group is responsible for computing a fragment size of 16*8
|
|
3693
|
+
// elements of matrix D.
|
|
3694
|
+
// Each work item computes 4 elements of matrix D by gathering
|
|
3695
|
+
// their corresponding row & col matrix fragments of length k (32)
|
|
3696
|
+
// from A & B matrices respectively using below mapping logic:
|
|
3697
|
+
// row0 = (lane >> 2) & row1 = (lane >> 2) + 8
|
|
3698
|
+
// col0 = ((lane % 4) * 4) + (i & 0x3) & col1 = ((lane % 4) * 4) + (i
|
|
3699
|
+
// & 0x3) As each row & col fragment of A & B matrices is distributed
|
|
3700
|
+
// across 4 work items, each iteration of below loop loads a partial
|
|
3701
|
+
// fragment of matrix A (row) and matrix B (col) using the row & col
|
|
3702
|
+
// offsets.
|
|
3703
|
+
for (int i = 0; i < 4; i++) {
|
|
3704
|
+
typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
|
|
3705
|
+
|
|
3706
|
+
// Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
|
|
3707
|
+
recv_a[0] =
|
|
3708
|
+
dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
|
|
3709
|
+
// Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})
|
|
3710
|
+
recv_a[1] =
|
|
3711
|
+
dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
|
|
3712
|
+
// Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
|
|
3713
|
+
recv_b[0] =
|
|
3714
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
|
|
3715
|
+
// Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})
|
|
3716
|
+
recv_b[1] =
|
|
3717
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
|
|
3718
|
+
|
|
3719
|
+
auto a = reinterpret_cast<ABType*>(recv_a);
|
|
3720
|
+
auto b = reinterpret_cast<ABType*>(recv_b);
|
|
3721
|
+
|
|
3722
|
+
// Each work item calculates a partial product of A & B matrix
|
|
3723
|
+
// fragments and adds it to the corresponding D matrix fragment d0
|
|
3724
|
+
// += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
|
|
3725
|
+
// a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a4, a5, a6,
|
|
3726
|
+
// a7 } * col0{ b0, b1, b2, b3 } d3 += row1{ a4, a5, a6, a7 } *
|
|
3727
|
+
// col1{ b0, b1, b2, b3 }
|
|
3728
|
+
for (int j = 0; j < 4; j++) {
|
|
3729
|
+
*d[0] += a[j] * b[j];
|
|
3730
|
+
*d[1] += a[j] * b[j + 4];
|
|
3731
|
+
*d[2] += a[j + 4] * b[j];
|
|
3732
|
+
*d[3] += a[j + 4] * b[j + 4];
|
|
3733
|
+
}
|
|
3734
|
+
}
|
|
3735
|
+
|
|
3736
|
+
for (int i = 0; i < 4; i++) {
|
|
3737
|
+
typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
|
|
3738
|
+
|
|
3739
|
+
// Load partial fragment from row0 of matrix A ({a8, a9, a10, a11})
|
|
3740
|
+
recv_a[0] =
|
|
3741
|
+
dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
|
|
3742
|
+
// Load partial fragment from row1 of matrix A ({a12, a13, a14,
|
|
3743
|
+
// a15})
|
|
3744
|
+
recv_a[1] =
|
|
3745
|
+
dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
|
|
3746
|
+
// Load partial fragment from col0 of matrix B ({b4, b5, b6, b7})
|
|
3747
|
+
recv_b[0] =
|
|
3748
|
+
dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
|
|
3749
|
+
// Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})
|
|
3750
|
+
recv_b[1] =
|
|
3751
|
+
dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 4);
|
|
3752
|
+
|
|
3753
|
+
auto a = reinterpret_cast<ABType*>(recv_a);
|
|
3754
|
+
auto b = reinterpret_cast<ABType*>(recv_b);
|
|
3755
|
+
|
|
3756
|
+
// Each work item calculates a partial product of A & B matrix
|
|
3757
|
+
// fragments and adds it to the corresponding D matrix fragment d0
|
|
3758
|
+
// += row0{ a8, a9, a10, a11 } * col0{ b4, b5, b6, b7 } d1 += row0{
|
|
3759
|
+
// a8, a9, a10, a11 } * col1{ b4, b5, b6, b7 } d2 += row1{ a12, a13,
|
|
3760
|
+
// a14, a15 } * col0{ b4, b5, b6, b7 } d3 += row1{ a12, a13, a14,
|
|
3761
|
+
// a15 } * col1{ b4, b5, b6, b7 }
|
|
3762
|
+
for (int j = 0; j < 4; j++) {
|
|
3763
|
+
*d[0] += a[j] * b[j];
|
|
3764
|
+
*d[1] += a[j] * b[j + 4];
|
|
3765
|
+
*d[2] += a[j + 4] * b[j];
|
|
3766
|
+
*d[3] += a[j + 4] * b[j + 4];
|
|
3767
|
+
}
|
|
3768
|
+
}
|
|
3769
|
+
}
|
|
3770
|
+
}
|
|
3771
|
+
}
|
|
2955
3772
|
} // COPY from DPCT head files
|
|
2956
3773
|
|
|
2957
3774
|
#endif // GGML_SYCL_DPCT_HELPER_HPP
|