whispercpp 1.3.5 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +99 -2
- data/ext/extconf.rb +1 -0
- data/ext/ruby_whisper.c +20 -4
- data/ext/ruby_whisper.h +30 -2
- data/ext/ruby_whisper_context.c +216 -124
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_params.c +0 -1
- data/ext/ruby_whisper_segment.c +0 -1
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +4 -1
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +8 -0
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/server/server.cpp +18 -4
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
- data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
- data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
- data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
- data/ext/sources/examples/talk-llama/llama-context.h +27 -28
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
- data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
- data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
- data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
- data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
- data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
- data/ext/sources/examples/talk-llama/llama-model.h +72 -19
- data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
- data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
- data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -22
- data/ext/sources/examples/talk-llama/llama.h +63 -30
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
- data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
- data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
- data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
- data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
- data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/models.h +181 -46
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
- data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
- data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
- data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
- data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
- data/ext/sources/ggml/CMakeLists.txt +9 -3
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +6 -1
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +56 -9
- data/ext/sources/ggml/src/CMakeLists.txt +3 -0
- data/ext/sources/ggml/src/ggml-alloc.c +4 -9
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
- data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
- data/ext/sources/ggml/src/ggml-impl.h +62 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
- data/ext/sources/ggml/src/ggml.c +167 -33
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/src/whisper.cpp +6 -28
- data/sig/whisper.rbs +43 -2
- data/test/test_context_params.rb +82 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_whisper.rb +20 -0
- data/whispercpp.gemspec +1 -1
- metadata +240 -28
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
|
@@ -15,18 +15,9 @@
|
|
|
15
15
|
|
|
16
16
|
#include <sycl/sycl.hpp>
|
|
17
17
|
#include <sycl/half_type.hpp>
|
|
18
|
-
#include <syclcompat/math.hpp>
|
|
19
|
-
#include <map>
|
|
20
|
-
|
|
21
|
-
#ifdef GGML_SYCL_USE_INTEL_ONEMKL
|
|
22
18
|
#include <oneapi/mkl.hpp>
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
namespace math = mkl;
|
|
26
|
-
}
|
|
27
|
-
#else
|
|
28
|
-
#include <oneapi/math.hpp>
|
|
29
|
-
#endif
|
|
19
|
+
|
|
20
|
+
#include <map>
|
|
30
21
|
|
|
31
22
|
#include "ggml.h"
|
|
32
23
|
|
|
@@ -92,32 +83,13 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
|
|
|
92
83
|
}
|
|
93
84
|
|
|
94
85
|
template <typename Ts> struct matrix_info_t {
|
|
95
|
-
oneapi::
|
|
86
|
+
oneapi::mkl::transpose transpose_info[2];
|
|
96
87
|
Ts value_info[2];
|
|
97
88
|
std::int64_t size_info[3];
|
|
98
89
|
std::int64_t ld_info[3];
|
|
99
90
|
std::int64_t groupsize_info;
|
|
100
91
|
};
|
|
101
92
|
|
|
102
|
-
inline auto get_onemath_backend(sycl::queue& queue)
|
|
103
|
-
#if defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
|
|
104
|
-
-> sycl::queue&
|
|
105
|
-
#endif
|
|
106
|
-
{
|
|
107
|
-
// If the backend is known at compile-time, use oneMath backend_selector to use
|
|
108
|
-
// compile-time dispatching and avoid the need to dlopen libraries. Otherwise
|
|
109
|
-
// fallback to runtime dispatching.
|
|
110
|
-
#if defined(GGML_SYCL_NVIDIA)
|
|
111
|
-
return oneapi::math::backend_selector<oneapi::math::backend::cublas>{ queue };
|
|
112
|
-
#elif defined(GGML_SYCL_AMD)
|
|
113
|
-
return oneapi::math::backend_selector<oneapi::math::backend::rocblas>{ queue };
|
|
114
|
-
#elif defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
|
|
115
|
-
return queue;
|
|
116
|
-
#else
|
|
117
|
-
static_assert(false, "Unsupported backend");
|
|
118
|
-
#endif
|
|
119
|
-
}
|
|
120
|
-
|
|
121
93
|
namespace dpct
|
|
122
94
|
{
|
|
123
95
|
typedef sycl::queue *queue_ptr;
|
|
@@ -1735,7 +1707,7 @@ namespace dpct
|
|
|
1735
1707
|
namespace detail
|
|
1736
1708
|
{
|
|
1737
1709
|
template <class Ta, class Tb, class Tc, class Ts>
|
|
1738
|
-
inline void gemm_impl(sycl::queue & q, oneapi::
|
|
1710
|
+
inline void gemm_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
|
|
1739
1711
|
int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb,
|
|
1740
1712
|
const void * beta, void * c, int ldc) {
|
|
1741
1713
|
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
|
@@ -1743,7 +1715,7 @@ namespace dpct
|
|
|
1743
1715
|
auto data_a = get_memory<const Ta>(a);
|
|
1744
1716
|
auto data_b = get_memory<const Tb>(b);
|
|
1745
1717
|
auto data_c = get_memory<Tc>(c);
|
|
1746
|
-
oneapi::
|
|
1718
|
+
oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a,
|
|
1747
1719
|
lda, data_b, ldb, beta_value, data_c, ldc);
|
|
1748
1720
|
}
|
|
1749
1721
|
|
|
@@ -1775,7 +1747,7 @@ namespace dpct
|
|
|
1775
1747
|
};
|
|
1776
1748
|
|
|
1777
1749
|
template <class Ta, class Tb, class Tc, class Ts>
|
|
1778
|
-
inline void gemm_batch_impl(sycl::queue & q, oneapi::
|
|
1750
|
+
inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
|
|
1779
1751
|
int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
|
|
1780
1752
|
int ldb, const void * beta, void ** c, int ldc, int batch_size,
|
|
1781
1753
|
matrix_info_t<float> * matrix_info) {
|
|
@@ -1794,8 +1766,8 @@ namespace dpct
|
|
|
1794
1766
|
matrix_info->ld_info[2] = ldc;
|
|
1795
1767
|
matrix_info->groupsize_info = batch_size;
|
|
1796
1768
|
|
|
1797
|
-
sycl::event e = oneapi::
|
|
1798
|
-
|
|
1769
|
+
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
|
1770
|
+
q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
|
|
1799
1771
|
matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2,
|
|
1800
1772
|
reinterpret_cast<Ts *>(matrix_info->value_info), reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
|
|
1801
1773
|
reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
|
@@ -1804,7 +1776,7 @@ namespace dpct
|
|
|
1804
1776
|
}
|
|
1805
1777
|
|
|
1806
1778
|
template <class Ta, class Tb, class Tc, class Ts>
|
|
1807
|
-
inline void gemm_batch_impl(sycl::queue & q, oneapi::
|
|
1779
|
+
inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
|
|
1808
1780
|
int m, int n, int k, const void * alpha, const void * a, int lda,
|
|
1809
1781
|
long long int stride_a, const void * b, int ldb, long long int stride_b,
|
|
1810
1782
|
const void * beta, void * c, int ldc, long long int stride_c, int batch_size) {
|
|
@@ -1813,7 +1785,7 @@ namespace dpct
|
|
|
1813
1785
|
auto data_a = get_memory<const Ta>(a);
|
|
1814
1786
|
auto data_b = get_memory<const Tb>(b);
|
|
1815
1787
|
auto data_c = get_memory<Tc>(c);
|
|
1816
|
-
oneapi::
|
|
1788
|
+
oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value,
|
|
1817
1789
|
data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
|
|
1818
1790
|
data_c, ldc, stride_c, batch_size);
|
|
1819
1791
|
}
|
|
@@ -2300,7 +2272,7 @@ namespace dpct
|
|
|
2300
2272
|
sycl::range<3>(x, y, 1), direction);
|
|
2301
2273
|
}
|
|
2302
2274
|
|
|
2303
|
-
inline void gemm(sycl::queue & q, oneapi::
|
|
2275
|
+
inline void gemm(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n,
|
|
2304
2276
|
int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b,
|
|
2305
2277
|
library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc,
|
|
2306
2278
|
library_data_t scaling_type) {
|
|
@@ -2367,7 +2339,7 @@ namespace dpct
|
|
|
2367
2339
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
2368
2340
|
library_data_t::real_float, library_data_t::real_float):
|
|
2369
2341
|
{
|
|
2370
|
-
detail::gemm_impl<oneapi::
|
|
2342
|
+
detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
|
|
2371
2343
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
|
|
2372
2344
|
break;
|
|
2373
2345
|
}
|
|
@@ -2406,7 +2378,7 @@ namespace dpct
|
|
|
2406
2378
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
2407
2379
|
library_data_t::real_bfloat16, library_data_t::real_float):
|
|
2408
2380
|
{
|
|
2409
|
-
detail::gemm_impl<oneapi::
|
|
2381
|
+
detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
|
|
2410
2382
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
|
|
2411
2383
|
break;
|
|
2412
2384
|
}
|
|
@@ -2448,7 +2420,7 @@ namespace dpct
|
|
|
2448
2420
|
/// \param [in] ldc Leading dimension of C.
|
|
2449
2421
|
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
|
|
2450
2422
|
/// \param [in] scaling_type Data type of the scaling factors.
|
|
2451
|
-
inline void gemm_batch(sycl::queue & q, oneapi::
|
|
2423
|
+
inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
|
|
2452
2424
|
int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
|
|
2453
2425
|
const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
|
|
2454
2426
|
library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
|
|
@@ -2486,7 +2458,7 @@ namespace dpct
|
|
|
2486
2458
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
2487
2459
|
library_data_t::real_bfloat16, library_data_t::real_float):
|
|
2488
2460
|
{
|
|
2489
|
-
detail::gemm_batch_impl<oneapi::
|
|
2461
|
+
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
|
|
2490
2462
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
|
2491
2463
|
break;
|
|
2492
2464
|
}
|
|
@@ -2494,7 +2466,7 @@ namespace dpct
|
|
|
2494
2466
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
2495
2467
|
library_data_t::real_float, library_data_t::real_float):
|
|
2496
2468
|
{
|
|
2497
|
-
detail::gemm_batch_impl<oneapi::
|
|
2469
|
+
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
|
|
2498
2470
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
|
2499
2471
|
break;
|
|
2500
2472
|
}
|
|
@@ -2570,7 +2542,7 @@ namespace dpct
|
|
|
2570
2542
|
/// \param [in] stride_c Stride between the different C matrices.
|
|
2571
2543
|
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
|
|
2572
2544
|
/// \param [in] scaling_type Data type of the scaling factors.
|
|
2573
|
-
inline void gemm_batch(sycl::queue & q, oneapi::
|
|
2545
|
+
inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
|
|
2574
2546
|
int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda,
|
|
2575
2547
|
long long int stride_a, const void * b, library_data_t b_type, int ldb,
|
|
2576
2548
|
long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc,
|
|
@@ -2643,7 +2615,7 @@ namespace dpct
|
|
|
2643
2615
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
2644
2616
|
library_data_t::real_bfloat16, library_data_t::real_float):
|
|
2645
2617
|
{
|
|
2646
|
-
detail::gemm_batch_impl<oneapi::
|
|
2618
|
+
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
|
|
2647
2619
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
|
|
2648
2620
|
batch_size);
|
|
2649
2621
|
break;
|
|
@@ -2652,7 +2624,7 @@ namespace dpct
|
|
|
2652
2624
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
2653
2625
|
library_data_t::real_float, library_data_t::real_float):
|
|
2654
2626
|
{
|
|
2655
|
-
detail::gemm_batch_impl<oneapi::
|
|
2627
|
+
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
|
|
2656
2628
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
|
|
2657
2629
|
batch_size);
|
|
2658
2630
|
break;
|
|
@@ -3025,6 +2997,778 @@ namespace dpct
|
|
|
3025
2997
|
return 0;
|
|
3026
2998
|
}
|
|
3027
2999
|
|
|
3000
|
+
template <int n_nondefault_params, int n_default_params, typename T>
|
|
3001
|
+
class args_selector;
|
|
3002
|
+
|
|
3003
|
+
/// args_selector is a helper class for extracting arguments from an
|
|
3004
|
+
/// array of pointers to arguments or buffer of arguments to pass to a
|
|
3005
|
+
/// kernel function.
|
|
3006
|
+
///
|
|
3007
|
+
/// \param R(Ts...) The type of the kernel
|
|
3008
|
+
/// \param n_nondefault_params The number of nondefault parameters of the
|
|
3009
|
+
/// kernel (excluding parameters that like sycl::nd_item, etc.) \param
|
|
3010
|
+
/// n_default_params The number of default parameters of the kernel
|
|
3011
|
+
///
|
|
3012
|
+
/// Example usage:
|
|
3013
|
+
/// With the following kernel:
|
|
3014
|
+
/// void foo(sycl::float2 *x, int n, sycl::nd_item<3> item_ct1, float
|
|
3015
|
+
/// f=.1) {}
|
|
3016
|
+
/// and with the declaration:
|
|
3017
|
+
/// args_selector<2, 1, decltype(foo)> selector(kernelParams, extra);
|
|
3018
|
+
/// we have:
|
|
3019
|
+
/// selector.get<0>() returns a reference to sycl::float*,
|
|
3020
|
+
/// selector.get<1>() returns a reference to int,
|
|
3021
|
+
/// selector.get<2>() returns a reference to float
|
|
3022
|
+
template <int n_nondefault_params, int n_default_params, typename R,
|
|
3023
|
+
typename... Ts>
|
|
3024
|
+
class args_selector<n_nondefault_params, n_default_params, R(Ts...)> {
|
|
3025
|
+
private:
|
|
3026
|
+
void **kernel_params;
|
|
3027
|
+
char *args_buffer;
|
|
3028
|
+
|
|
3029
|
+
template <int i> static constexpr int account_for_default_params() {
|
|
3030
|
+
constexpr int n_total_params = sizeof...(Ts);
|
|
3031
|
+
if constexpr (i >= n_nondefault_params) {
|
|
3032
|
+
return n_total_params - n_default_params +
|
|
3033
|
+
(i - n_nondefault_params);
|
|
3034
|
+
} else {
|
|
3035
|
+
return i;
|
|
3036
|
+
}
|
|
3037
|
+
}
|
|
3038
|
+
|
|
3039
|
+
public:
|
|
3040
|
+
/// Get the type of the ith argument of R(Ts...)
|
|
3041
|
+
/// \param [in] i Index of parameter to get
|
|
3042
|
+
/// \returns Type of ith parameter
|
|
3043
|
+
template <int i>
|
|
3044
|
+
using arg_type = std::tuple_element_t<account_for_default_params<i>(),
|
|
3045
|
+
std::tuple<Ts...>>;
|
|
3046
|
+
static constexpr int params_num = sizeof...(Ts);
|
|
3047
|
+
|
|
3048
|
+
private:
|
|
3049
|
+
template <int i> static constexpr int get_offset() {
|
|
3050
|
+
if constexpr (i == 0) {
|
|
3051
|
+
// we can assume args_buffer is properly aligned to the
|
|
3052
|
+
// first argument
|
|
3053
|
+
return 0;
|
|
3054
|
+
} else {
|
|
3055
|
+
constexpr int prev_off = get_offset<i - 1>();
|
|
3056
|
+
constexpr int prev_past_end =
|
|
3057
|
+
prev_off + sizeof(arg_type<i - 1>);
|
|
3058
|
+
using T = arg_type<i>;
|
|
3059
|
+
// is the past-the-end of the i-1st element properly aligned
|
|
3060
|
+
// with the ith element's alignment?
|
|
3061
|
+
if constexpr (prev_past_end % alignof(T) == 0) {
|
|
3062
|
+
return prev_past_end;
|
|
3063
|
+
}
|
|
3064
|
+
// otherwise bump prev_past_end to match alignment
|
|
3065
|
+
else {
|
|
3066
|
+
return prev_past_end +
|
|
3067
|
+
(alignof(T) - (prev_past_end % alignof(T)));
|
|
3068
|
+
}
|
|
3069
|
+
}
|
|
3070
|
+
}
|
|
3071
|
+
|
|
3072
|
+
static char *get_args_buffer(void **extra) {
|
|
3073
|
+
if (!extra)
|
|
3074
|
+
return nullptr;
|
|
3075
|
+
for (; (std::size_t)*extra != 0; ++extra) {
|
|
3076
|
+
if ((std::size_t)*extra == 1) {
|
|
3077
|
+
return static_cast<char *>(*(extra + 1));
|
|
3078
|
+
}
|
|
3079
|
+
}
|
|
3080
|
+
return nullptr;
|
|
3081
|
+
}
|
|
3082
|
+
|
|
3083
|
+
public:
|
|
3084
|
+
/// If kernel_params is nonnull, then args_selector will
|
|
3085
|
+
/// extract arguments from kernel_params. Otherwise, it
|
|
3086
|
+
/// will extract them from extra.
|
|
3087
|
+
/// \param [in] kernel_params Array of pointers to arguments
|
|
3088
|
+
/// a or null pointer.
|
|
3089
|
+
/// \param [in] extra Array containing pointer to argument buffer.
|
|
3090
|
+
args_selector(void **kernel_params, void **extra)
|
|
3091
|
+
: kernel_params(kernel_params),
|
|
3092
|
+
args_buffer(get_args_buffer(extra)) {}
|
|
3093
|
+
|
|
3094
|
+
/// Get a reference to the ith argument extracted from kernel_params
|
|
3095
|
+
/// or extra.
|
|
3096
|
+
/// \param [in] i Index of argument to get
|
|
3097
|
+
/// \returns Reference to the ith argument
|
|
3098
|
+
template <int i> arg_type<i> &get() {
|
|
3099
|
+
if (kernel_params) {
|
|
3100
|
+
return *static_cast<arg_type<i> *>(kernel_params[i]);
|
|
3101
|
+
} else {
|
|
3102
|
+
return *reinterpret_cast<arg_type<i> *>(args_buffer +
|
|
3103
|
+
get_offset<i>());
|
|
3104
|
+
}
|
|
3105
|
+
}
|
|
3106
|
+
}; // COPY from DPCT head file
|
|
3107
|
+
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp
|
|
3108
|
+
|
|
3109
|
+
/// Utility class for launching SYCL kernels through kernel
|
|
3110
|
+
/// function wrapper.
|
|
3111
|
+
/// For example:
|
|
3112
|
+
/// A SYCL kernel function:
|
|
3113
|
+
/// void kernel_func(int *ptr, sycl::nd_item<3> item);
|
|
3114
|
+
/// Kernel function wrapper:
|
|
3115
|
+
/// void kernel_func_wrapper(int *ptr) {
|
|
3116
|
+
/// sycl::queue queue = *dpct::kernel_launcher::_que;
|
|
3117
|
+
/// unsigned int localMemSize = dpct::kernel_launcher::_local_mem_size;
|
|
3118
|
+
/// sycl::nd_range<3> nr = dpct::kernel_launcher::_nr;
|
|
3119
|
+
/// queue.parallel_for(
|
|
3120
|
+
/// nr,
|
|
3121
|
+
/// [=](sycl::nd_item<3> item_ct1) {
|
|
3122
|
+
/// kernel_func(ptr, item_ct1);
|
|
3123
|
+
/// });
|
|
3124
|
+
/// }
|
|
3125
|
+
/// Then launch the kernel through wrapper like:
|
|
3126
|
+
/// typedef void(*fpt)(int *);
|
|
3127
|
+
/// fpt fp = kernel_func_wrapper;
|
|
3128
|
+
/// dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), 0, 0,
|
|
3129
|
+
/// device_ptr);
|
|
3130
|
+
/// If the origin function type is erased, then need to register it first:
|
|
3131
|
+
/// void *fp = (void *)wrapper_register(&kernel_func_wrapper).get();
|
|
3132
|
+
/// dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), args,
|
|
3133
|
+
/// 0, 0);
|
|
3134
|
+
class kernel_launcher {
|
|
3135
|
+
template <typename FuncT, typename ArgSelector, std::size_t... Index>
|
|
3136
|
+
static void launch_helper(FuncT &&func, ArgSelector &selector,
|
|
3137
|
+
std::index_sequence<Index...>) {
|
|
3138
|
+
func(selector.template get<Index>()...);
|
|
3139
|
+
}
|
|
3140
|
+
static void set_execution_config(dim3 group_range, dim3 local_range,
|
|
3141
|
+
unsigned int local_mem_size,
|
|
3142
|
+
queue_ptr que) {
|
|
3143
|
+
if (que) {
|
|
3144
|
+
_que = que;
|
|
3145
|
+
} else {
|
|
3146
|
+
_que = &get_default_queue();
|
|
3147
|
+
}
|
|
3148
|
+
_nr = sycl::nd_range<3>(
|
|
3149
|
+
static_cast<sycl::range<3>>(group_range * local_range),
|
|
3150
|
+
static_cast<sycl::range<3>>(local_range));
|
|
3151
|
+
_local_mem_size = local_mem_size;
|
|
3152
|
+
|
|
3153
|
+
|
|
3154
|
+
};
|
|
3155
|
+
static inline std::mutex kernel_function_ptr_map_mutex;
|
|
3156
|
+
|
|
3157
|
+
public:
|
|
3158
|
+
/// Variables for storing execution configuration.
|
|
3159
|
+
static inline thread_local sycl::queue *_que = nullptr;
|
|
3160
|
+
static inline thread_local sycl::nd_range<3> _nr = sycl::nd_range<3>();
|
|
3161
|
+
static inline thread_local unsigned int _local_mem_size = 0;
|
|
3162
|
+
/// Map for retrieving launchable functor from a raw pointer.
|
|
3163
|
+
static inline std::map<
|
|
3164
|
+
const void *,
|
|
3165
|
+
std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>>
|
|
3166
|
+
kernel_function_ptr_map = {};
|
|
3167
|
+
|
|
3168
|
+
/// Registers a kernel function pointer with a corresponding launchable
|
|
3169
|
+
/// functor.
|
|
3170
|
+
/// \param [in] func Pointer to the kernel function.
|
|
3171
|
+
/// \param [in] launcher Functor to handle kernel invocation.
|
|
3172
|
+
static void register_kernel_ptr(
|
|
3173
|
+
const void *func,
|
|
3174
|
+
std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>
|
|
3175
|
+
launcher) {
|
|
3176
|
+
std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex);
|
|
3177
|
+
kernel_function_ptr_map[func] = std::move(launcher);
|
|
3178
|
+
}
|
|
3179
|
+
/// Launches a kernel function with arguments provided directly through
|
|
3180
|
+
/// kernel function wrapper.
|
|
3181
|
+
/// \tparam FuncT Type of the kernel function wrapper.
|
|
3182
|
+
/// \tparam ArgsT Types of kernel arguments.
|
|
3183
|
+
/// \param [in] func Pointer to the kernel function wrapper.
|
|
3184
|
+
/// \param [in] group_range SYCL group range.
|
|
3185
|
+
/// \param [in] local_range SYCL local range.
|
|
3186
|
+
/// \param [in] local_mem_size The size of local memory required by the
|
|
3187
|
+
/// kernel function. \param [in] que SYCL queue used to execute kernel.
|
|
3188
|
+
/// \param [in] args Kernel arguments.
|
|
3189
|
+
template <typename FuncT, typename... ArgsT>
|
|
3190
|
+
static std::enable_if_t<std::is_invocable_v<FuncT *, ArgsT...>, void>
|
|
3191
|
+
launch(FuncT *func, dim3 group_range, dim3 local_range,
|
|
3192
|
+
unsigned int local_mem_size, queue_ptr que, ArgsT... args) {
|
|
3193
|
+
set_execution_config(group_range, local_range, local_mem_size, que);
|
|
3194
|
+
func(args...);
|
|
3195
|
+
}
|
|
3196
|
+
/// Launches a kernel function through registered kernel function
|
|
3197
|
+
/// wrapper. \param [in] func Pointer to the registered kernel function
|
|
3198
|
+
/// wrapper. \param [in] group_range SYCL group range. \param [in]
|
|
3199
|
+
/// local_range SYCL local range. \param [in] args Array of pointers to
|
|
3200
|
+
/// kernel arguments. \param [in] local_mem_size The size of local
|
|
3201
|
+
/// memory required by the kernel function. \param [in] que SYCL queue
|
|
3202
|
+
/// used to execute kernel.
|
|
3203
|
+
static void launch(const void *func, dim3 group_range, dim3 local_range,
|
|
3204
|
+
void **args, unsigned int local_mem_size,
|
|
3205
|
+
queue_ptr que) {
|
|
3206
|
+
std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex);
|
|
3207
|
+
auto Iter = kernel_function_ptr_map.find(func);
|
|
3208
|
+
if (Iter == kernel_function_ptr_map.end()) {
|
|
3209
|
+
throw std::runtime_error("dpct::launch() : no registered "
|
|
3210
|
+
"kernel function wrapper found.");
|
|
3211
|
+
}
|
|
3212
|
+
(Iter->second)(group_range, local_range, args, local_mem_size, que);
|
|
3213
|
+
}
|
|
3214
|
+
/// Launches a kernel function with packed arguments through kernel
|
|
3215
|
+
/// function wrapper.
|
|
3216
|
+
/// \tparam FuncT Type of the kernel function wrapper.
|
|
3217
|
+
/// \param [in] func Pointer to the kernel function wrapper.
|
|
3218
|
+
/// \param [in] group_range SYCL group range.
|
|
3219
|
+
/// \param [in] local_range SYCL local range.
|
|
3220
|
+
/// \param [in] args Array of pointers to kernel arguments.
|
|
3221
|
+
/// \param [in] local_mem_size The size of local memory required by the
|
|
3222
|
+
/// kernel function. \param [in] que SYCL queue used to execute kernel.
|
|
3223
|
+
template <typename FuncT>
|
|
3224
|
+
static std::enable_if_t<std::is_function_v<FuncT>, void>
|
|
3225
|
+
launch(FuncT *func, dim3 group_range, dim3 local_range, void **args,
|
|
3226
|
+
unsigned int local_mem_size, queue_ptr que) {
|
|
3227
|
+
constexpr size_t p_num = args_selector<0, 0, FuncT>::params_num;
|
|
3228
|
+
set_execution_config(group_range, local_range, local_mem_size, que);
|
|
3229
|
+
args_selector<p_num, p_num, FuncT> selector(args, nullptr);
|
|
3230
|
+
launch_helper(func, selector, std::make_index_sequence<p_num>{});
|
|
3231
|
+
}
|
|
3232
|
+
}; // COPY from DPCT head file
|
|
3233
|
+
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/kernel.hpp
|
|
3234
|
+
|
|
3235
|
+
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp
|
|
3236
|
+
template <typename T>
|
|
3237
|
+
T select_from_sub_group(
|
|
3238
|
+
sycl::sub_group g,
|
|
3239
|
+
T x,
|
|
3240
|
+
int remote_local_id,
|
|
3241
|
+
int logical_sub_group_size = 32) {
|
|
3242
|
+
unsigned int start_index = g.get_local_linear_id() /
|
|
3243
|
+
logical_sub_group_size *
|
|
3244
|
+
logical_sub_group_size;
|
|
3245
|
+
return sycl::select_from_group(
|
|
3246
|
+
g, x, start_index + remote_local_id % logical_sub_group_size);
|
|
3247
|
+
}
|
|
3248
|
+
|
|
3249
|
+
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp
|
|
3250
|
+
template <typename T>
|
|
3251
|
+
void ldmatrix(uintptr_t addr, T* m, bool trans = false, unsigned mat = 0) {
|
|
3252
|
+
auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
|
|
3253
|
+
int lane = sg.get_local_linear_id();
|
|
3254
|
+
|
|
3255
|
+
int lane_group8_row = lane / 8;
|
|
3256
|
+
int lane_group8_col = lane % 8;
|
|
3257
|
+
|
|
3258
|
+
if (!trans) {
|
|
3259
|
+
// calculate the source lane
|
|
3260
|
+
int src_lane = 2 * lane_group8_row;
|
|
3261
|
+
if (lane_group8_col >= 4)
|
|
3262
|
+
src_lane += 1;
|
|
3263
|
+
|
|
3264
|
+
// Broadcast the address from the source lane
|
|
3265
|
+
auto recv_addr_uintp =
|
|
3266
|
+
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
|
|
3267
|
+
|
|
3268
|
+
// Cast the received address from uintptr_t to the type of 'm'
|
|
3269
|
+
auto recv_addr = reinterpret_cast<T*>(recv_addr_uintp);
|
|
3270
|
+
|
|
3271
|
+
// Non-transposed load
|
|
3272
|
+
*m = recv_addr[lane_group8_col % 4];
|
|
3273
|
+
} else {
|
|
3274
|
+
// calculate the source lane
|
|
3275
|
+
int src_lane = (lane % 4) * 2;
|
|
3276
|
+
|
|
3277
|
+
// Broadcast the address from the source lane
|
|
3278
|
+
auto recv_addr_uintp_1 =
|
|
3279
|
+
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
|
|
3280
|
+
auto recv_addr_uintp_2 =
|
|
3281
|
+
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane + 1);
|
|
3282
|
+
|
|
3283
|
+
// Cast the received address from uintptr_t to 'half *'
|
|
3284
|
+
auto recv_addr_1 = reinterpret_cast<sycl::half*>(recv_addr_uintp_1);
|
|
3285
|
+
auto recv_addr_2 = reinterpret_cast<sycl::half*>(recv_addr_uintp_2);
|
|
3286
|
+
|
|
3287
|
+
// Transposed load
|
|
3288
|
+
int index = lane / 4;
|
|
3289
|
+
sycl::half val0 = recv_addr_1[index];
|
|
3290
|
+
sycl::half val1 = recv_addr_2[index];
|
|
3291
|
+
|
|
3292
|
+
// Combine the two 16-bits into one 32-bit value
|
|
3293
|
+
sycl::half2 val = sycl::half2(val0, val1);
|
|
3294
|
+
*m = *reinterpret_cast<T*>(&val);
|
|
3295
|
+
}
|
|
3296
|
+
}
|
|
3297
|
+
|
|
3298
|
+
/// Two-tile variant: loads tile 0 into *m1 and tile 1 into *m2 by invoking
/// the single-tile ldmatrix() with `mat` = 0 and 1, in that order.
///
/// \param [in] addr Address contributed by this work item (as uintptr_t).
/// \param [out] m1 Receives this work item's element of tile 0.
/// \param [out] m2 Receives this work item's element of tile 1.
/// \param [in] trans Whether to perform transposed loads.
template <typename T>
void ldmatrix(uintptr_t addr, T* m1, T* m2, bool trans = false) {
  T* const dst[] = {m1, m2};
  for (unsigned tile = 0; tile < 2; ++tile) {
    ldmatrix(addr, dst[tile], trans, tile);
  }
}
|
|
3305
|
+
|
|
3306
|
+
/// Four-tile variant: loads tiles 0..3 into *m1..*m4 respectively by invoking
/// the single-tile ldmatrix() with `mat` = 0, 1, 2, 3, in that order.
///
/// \param [in] addr Address contributed by this work item (as uintptr_t).
/// \param [out] m1 Receives this work item's element of tile 0.
/// \param [out] m2 Receives this work item's element of tile 1.
/// \param [out] m3 Receives this work item's element of tile 2.
/// \param [out] m4 Receives this work item's element of tile 3.
/// \param [in] trans Whether to perform transposed loads.
template <typename T>
void ldmatrix(
    uintptr_t addr, T* m1, T* m2, T* m3, T* m4, bool trans = false) {
  T* const dst[] = {m1, m2, m3, m4};
  for (unsigned tile = 0; tile < 4; ++tile) {
    ldmatrix(addr, dst[tile], trans, tile);
  }
}
|
|
3318
|
+
|
|
3319
|
+
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp
|
|
3320
|
+
|
|
3321
|
+
/// Pack type used by mma() to move input matrix fragments between the work
/// items of a sub-group.
///
/// mma() reinterprets its A and B fragment pointers as arrays of
/// MMAType<ABType>::PackType, shuffles whole packs with
/// dpct::select_from_sub_group, and then reinterprets each received pack as an
/// array of ABType elements. No specialization follows in this header, so
/// every element type (f16, bf16, s8, ...) uses a 32-bit pack unless it is
/// specialized elsewhere.
///
/// \tparam [in] T The element type of the input matrix fragments
template <typename T>
struct MMAType {
  using PackType = uint32_t;
};
|
|
3331
|
+
|
|
3332
|
+
/// Each work item of a sub-group (limited to size 32) calling this function
|
|
3333
|
+
/// calculates a subset fragment for the output matrix D using MAD operation
|
|
3334
|
+
/// on A, B & C matrix fragments (D = A * B + C). Current supported shapes &
|
|
3335
|
+
/// types:
|
|
3336
|
+
/// - m8n8k4 (f32.f16.f16.f32)
|
|
3337
|
+
/// - m8n8k16 (s32.s8.s8.s32)
|
|
3338
|
+
/// - m16n8k8 (f32.f16.f16.f32 & f32.bf16.bf16.f32)
|
|
3339
|
+
/// - m16n8k16 (f32.f16.f16.f32 & s32.s8.s8.s32)
|
|
3340
|
+
/// - m16n8k32 (s32.s8.s8.s32)
|
|
3341
|
+
/// Here, m, n & k define the shapes of A, B & C matrices respectively
|
|
3342
|
+
/// (A = [m x k], B = [k x n], C = [m x n]).
|
|
3343
|
+
/// \tparam [in] M The rows of A, C & D matrices
|
|
3344
|
+
/// \tparam [in] N The columns of B, C, D matrices
|
|
3345
|
+
/// \tparam [in] K The columns & rows of A & B matrices respectively
|
|
3346
|
+
/// \tparam [in] ABType The type of the input matrix (A & B) fragment
|
|
3347
|
+
/// \tparam [in] CDType The type of the output matrix (C & D) fragment
|
|
3348
|
+
/// \param [out] d_mat_frag The fragment of the output matrix D to store the
|
|
3349
|
+
/// result of A * B + C
|
|
3350
|
+
/// \param [in] a_mat_frag The fragment of the input matrix A to be
|
|
3351
|
+
/// multiplied with B matrix fragment \param [in] b_mat_frag The fragment of
|
|
3352
|
+
/// the input matrix B to be multiplied with A matrix fragment \param [in]
|
|
3353
|
+
/// c_mat_frag The fragment of the input matrix C to be added with the
|
|
3354
|
+
/// result of A * B fragments
|
|
3355
|
+
template <int M, int N, int K, typename ABType, typename CDType>
|
|
3356
|
+
void mma(
|
|
3357
|
+
volatile void** d_mat_frag,
|
|
3358
|
+
void* a_mat_frag,
|
|
3359
|
+
void* b_mat_frag,
|
|
3360
|
+
void* c_mat_frag) {
|
|
3361
|
+
auto d = reinterpret_cast<volatile CDType**>(d_mat_frag);
|
|
3362
|
+
auto a =
|
|
3363
|
+
reinterpret_cast<typename MMAType<ABType>::PackType*>(a_mat_frag);
|
|
3364
|
+
auto b =
|
|
3365
|
+
reinterpret_cast<typename MMAType<ABType>::PackType*>(b_mat_frag);
|
|
3366
|
+
auto c = reinterpret_cast<CDType*>(c_mat_frag);
|
|
3367
|
+
|
|
3368
|
+
auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
|
|
3369
|
+
int lane = sg.get_local_linear_id();
|
|
3370
|
+
|
|
3371
|
+
static_assert(
|
|
3372
|
+
(M == 8 && N == 8 && K == 4) || (M == 8 && N == 8 && K == 16) ||
|
|
3373
|
+
(M == 16 && N == 8 && K == 8) || (M == 16 && N == 8 && K == 16) ||
|
|
3374
|
+
(M == 16 && N == 8 && K == 32),
|
|
3375
|
+
"Unsupported MMA shape!");
|
|
3376
|
+
|
|
3377
|
+
short row_load_offset = 4 * (lane >> 2);
|
|
3378
|
+
short col_load_offset = 8 * (lane % 4);
|
|
3379
|
+
|
|
3380
|
+
if constexpr (M == 8 && N == 8 && K == 4) {
|
|
3381
|
+
if constexpr (std::is_floating_point_v<CDType>) {
|
|
3382
|
+
col_load_offset = row_load_offset % 16;
|
|
3383
|
+
|
|
3384
|
+
// Init D matrix with fragments of C matrix
|
|
3385
|
+
*d[0] = c[0];
|
|
3386
|
+
*d[1] = c[1];
|
|
3387
|
+
*d[2] = c[2];
|
|
3388
|
+
*d[3] = c[3];
|
|
3389
|
+
*d[4] = c[4];
|
|
3390
|
+
*d[5] = c[5];
|
|
3391
|
+
*d[6] = c[6];
|
|
3392
|
+
*d[7] = c[7];
|
|
3393
|
+
|
|
3394
|
+
// Calculate the row and col offset indices to iterate through the row
|
|
3395
|
+
// & col fragments of A & B matrices
|
|
3396
|
+
int r_ind = (lane % 2) ? 1 : 0;
|
|
3397
|
+
int c_ind = ((lane % 4) / 2) ? 2 : 0;
|
|
3398
|
+
|
|
3399
|
+
// Each sub-group is responsible for computing a fragment size of 8*8
|
|
3400
|
+
// elements of matrix D for each of 4 MMA computations.
|
|
3401
|
+
// Each work item computes 8 elements of matrix D by gathering
|
|
3402
|
+
// their corresponding col & row matrix fragments of length k (4)
|
|
3403
|
+
// from A & B matrices respectively using below mapping logic:
|
|
3404
|
+
// row0 = (i % 4) if (lane < 16) else (i % 4) + 4
|
|
3405
|
+
// col0 = (lane % 4)
|
|
3406
|
+
// As each row & col fragment of A & B matrices is distributed across
|
|
3407
|
+
// 4 work items, each iteration of below loop loads a partial fragment
|
|
3408
|
+
// of matrix A (row) and matrix B (col) using the row & col offsets.
|
|
3409
|
+
typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
|
|
3410
|
+
|
|
3411
|
+
for (int i = 0; i < 4; i++) {
|
|
3412
|
+
// Load partial fragment from col0 of matrix A ({a0, a1})
|
|
3413
|
+
recv_a[0] =
|
|
3414
|
+
dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
|
|
3415
|
+
// Load partial fragment from col0 of matrix A ({a2, a3})
|
|
3416
|
+
recv_a[1] =
|
|
3417
|
+
dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
|
|
3418
|
+
|
|
3419
|
+
// Load partial fragment from row0 of matrix B ({b0, b1})
|
|
3420
|
+
recv_b[0] =
|
|
3421
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
|
|
3422
|
+
// Load partial fragment from row0 of matrix B ({b2, b3})
|
|
3423
|
+
recv_b[1] =
|
|
3424
|
+
dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
|
|
3425
|
+
|
|
3426
|
+
auto ra = reinterpret_cast<ABType*>(recv_a);
|
|
3427
|
+
auto rb = reinterpret_cast<ABType*>(recv_b);
|
|
3428
|
+
|
|
3429
|
+
// Each work item calculates a partial product of A & B matrix
|
|
3430
|
+
// fragments and adds it to the corresponding D matrix fragment (for
|
|
3431
|
+
// even work item indices) d0 += col0{ a0 } * row0{ b0 } d1 += col0{
|
|
3432
|
+
// a0 } * row0{ b1 } d2 += col1{ a2 } * row0{ b0 } d3 += col1{ a2 }
|
|
3433
|
+
// * row0{ b1 } (for odd work item indices) d0 += col0{ a1 } * row0{
|
|
3434
|
+
// b2 } d1 += col0{ a1 } * row0{ b3 } d2 += col1{ a3 } * row0{ b2 }
|
|
3435
|
+
// d3 += col1{ a3 } * row0{ b3 }
|
|
3436
|
+
*d[0] +=
|
|
3437
|
+
static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]);
|
|
3438
|
+
*d[1] += static_cast<float>(ra[r_ind]) *
|
|
3439
|
+
static_cast<float>(rb[c_ind + 1]);
|
|
3440
|
+
*d[2] += static_cast<float>(ra[r_ind + 2]) *
|
|
3441
|
+
static_cast<float>(rb[c_ind]);
|
|
3442
|
+
*d[3] += static_cast<float>(ra[r_ind + 2]) *
|
|
3443
|
+
static_cast<float>(rb[c_ind + 1]);
|
|
3444
|
+
|
|
3445
|
+
// Load partial fragment from row1 of matrix B ({b0, b1})
|
|
3446
|
+
recv_b[0] =
|
|
3447
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 16);
|
|
3448
|
+
// Load partial fragment from row1 of matrix B ({b2, b3})
|
|
3449
|
+
recv_b[1] =
|
|
3450
|
+
dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 16);
|
|
3451
|
+
|
|
3452
|
+
// (for even work item indices)
|
|
3453
|
+
// d0 += col0{ a0 } * row1{ b0 }
|
|
3454
|
+
// d1 += col0{ a0 } * row1{ b1 }
|
|
3455
|
+
// d2 += col1{ a2 } * row1{ b0 }
|
|
3456
|
+
// d3 += col1{ a2 } * row1{ b1 }
|
|
3457
|
+
// (for odd work item indices)
|
|
3458
|
+
// d0 += col0{ a1 } * row1{ b2 }
|
|
3459
|
+
// d1 += col0{ a1 } * row1{ b3 }
|
|
3460
|
+
// d2 += col1{ a3 } * row1{ b2 }
|
|
3461
|
+
// d3 += col1{ a3 } * row1{ b3 }
|
|
3462
|
+
*d[4] +=
|
|
3463
|
+
static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]);
|
|
3464
|
+
*d[5] += static_cast<float>(ra[r_ind]) *
|
|
3465
|
+
static_cast<float>(rb[c_ind + 1]);
|
|
3466
|
+
*d[6] += static_cast<float>(ra[r_ind + 2]) *
|
|
3467
|
+
static_cast<float>(rb[c_ind]);
|
|
3468
|
+
*d[7] += static_cast<float>(ra[r_ind + 2]) *
|
|
3469
|
+
static_cast<float>(rb[c_ind + 1]);
|
|
3470
|
+
}
|
|
3471
|
+
}
|
|
3472
|
+
} else if constexpr (M == 8 && N == 8 && K == 16) {
|
|
3473
|
+
if constexpr (std::is_integral_v<ABType>) {
|
|
3474
|
+
// Init D matrix with fragments of C matrix
|
|
3475
|
+
*d[0] = c[0];
|
|
3476
|
+
*d[1] = c[1];
|
|
3477
|
+
|
|
3478
|
+
// Each sub-group is responsible for computing a fragment size of 16*8
|
|
3479
|
+
// elements of matrix D.
|
|
3480
|
+
// Each work item computes 2 elements of matrix D by gathering
|
|
3481
|
+
// their corresponding row & col matrix fragments of length k (16)
|
|
3482
|
+
// from A & B matrices respectively using below mapping logic:
|
|
3483
|
+
// row0 = ((lane % 4) * 4) + i
|
|
3484
|
+
// col0 = (lane >> 2)
|
|
3485
|
+
// As each row & col fragment of A & B matrices is distributed across
|
|
3486
|
+
// 4 work items, each iteration of below loop loads a partial fragment
|
|
3487
|
+
// of matrix A (row) and matrix B (col) using the row & col offsets.
|
|
3488
|
+
for (int i = 0; i < 4; i++) {
|
|
3489
|
+
typename MMAType<ABType>::PackType recv_a, recv_b[2];
|
|
3490
|
+
|
|
3491
|
+
// Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
|
|
3492
|
+
recv_a = dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
|
|
3493
|
+
// Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
|
|
3494
|
+
recv_b[0] =
|
|
3495
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
|
|
3496
|
+
// Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})
|
|
3497
|
+
recv_b[1] =
|
|
3498
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
|
|
3499
|
+
|
|
3500
|
+
auto a = reinterpret_cast<ABType*>(&recv_a);
|
|
3501
|
+
auto b = reinterpret_cast<ABType*>(recv_b);
|
|
3502
|
+
|
|
3503
|
+
// Each work item calculates a partial product of A & B matrix
|
|
3504
|
+
// fragments and adds it to the corresponding D matrix fragment d0
|
|
3505
|
+
// += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
|
|
3506
|
+
// a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row0{ a0, a1, a2,
|
|
3507
|
+
// a3 } * col0{ b0, b1, b2, b3 } d3 += row0{ a0, a1, a2, a3 } *
|
|
3508
|
+
// col1{ b0, b1, b2, b3 }
|
|
3509
|
+
for (int j = 0; j < 4; j++) {
|
|
3510
|
+
*d[0] += a[j] * b[j];
|
|
3511
|
+
*d[1] += a[j] * b[j + 4];
|
|
3512
|
+
}
|
|
3513
|
+
}
|
|
3514
|
+
}
|
|
3515
|
+
} else if constexpr (M == 16 && N == 8 && K == 8) {
|
|
3516
|
+
if constexpr (std::is_floating_point_v<CDType>) {
|
|
3517
|
+
// Init D matrix fragment with C matrix fragment
|
|
3518
|
+
*d[0] = c[0];
|
|
3519
|
+
*d[1] = c[1];
|
|
3520
|
+
*d[2] = c[2];
|
|
3521
|
+
*d[3] = c[3];
|
|
3522
|
+
|
|
3523
|
+
// Each sub-group is responsible for computing a fragment size of 16*8
|
|
3524
|
+
// elements of matrix D.
|
|
3525
|
+
// Each work item computes 4 elements of matrix D by gathering
|
|
3526
|
+
// their corresponding row & col matrix fragments of length k (8)
|
|
3527
|
+
// from A & B matrices respectively using below mapping logic:
|
|
3528
|
+
// row0 = (lane >> 2) & row1 = (lane >> 2) + 8
|
|
3529
|
+
// col0 = (lane % 4) * 2 + (i & 0x1)
|
|
3530
|
+
// As each row & col fragment of A & B matrices is distributed across
|
|
3531
|
+
// 4 work items, each iteration of below loop loads a partial fragment
|
|
3532
|
+
// of matrix A (row) and matrix B (col) using the row & col offsets.
|
|
3533
|
+
for (int i = 0; i < 4; i++) {
|
|
3534
|
+
typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
|
|
3535
|
+
|
|
3536
|
+
// Load partial fragment from row0 of matrix A ({a0, a1})
|
|
3537
|
+
recv_a[0] =
|
|
3538
|
+
dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
|
|
3539
|
+
// Load partial fragment from row1 of matrix A ({a2, a3})
|
|
3540
|
+
recv_a[1] =
|
|
3541
|
+
dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
|
|
3542
|
+
// Load partial fragment from col0 of matrix B ({b0, b1})
|
|
3543
|
+
recv_b[0] =
|
|
3544
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
|
|
3545
|
+
// Load partial fragment from col1 of matrix B ({b0, b1})
|
|
3546
|
+
recv_b[1] =
|
|
3547
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
|
|
3548
|
+
|
|
3549
|
+
auto ra = reinterpret_cast<ABType*>(recv_a);
|
|
3550
|
+
auto rb = reinterpret_cast<ABType*>(recv_b);
|
|
3551
|
+
|
|
3552
|
+
// Each work item calculates a partial product of A & B matrix
|
|
3553
|
+
// fragments and adds it to the corresponding D matrix fragment d0
|
|
3554
|
+
// += row0{ a0, a1 } * col0{ b0, b1 } d1 += row0{ a0, a1 } * col1{
|
|
3555
|
+
// b0, b1 } d2 += row1{ a2, a3 } * col0{ b0, b1 } d3 += row1{ a2, a3
|
|
3556
|
+
// } * col1{ b0, b1 }
|
|
3557
|
+
for (int j = 0; j < 2; j++) {
|
|
3558
|
+
*d[0] += static_cast<float>(ra[j]) * static_cast<float>(rb[j]);
|
|
3559
|
+
*d[1] +=
|
|
3560
|
+
static_cast<float>(ra[j]) * static_cast<float>(rb[j + 2]);
|
|
3561
|
+
*d[2] +=
|
|
3562
|
+
static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j]);
|
|
3563
|
+
*d[3] +=
|
|
3564
|
+
static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j + 2]);
|
|
3565
|
+
}
|
|
3566
|
+
}
|
|
3567
|
+
}
|
|
3568
|
+
} else if constexpr (M == 16 && N == 8 && K == 16) {
|
|
3569
|
+
if constexpr (std::is_floating_point_v<CDType>) {
|
|
3570
|
+
// Init D matrix fragment with C matrix fragment
|
|
3571
|
+
*d[0] = c[0];
|
|
3572
|
+
*d[1] = c[1];
|
|
3573
|
+
*d[2] = c[2];
|
|
3574
|
+
*d[3] = c[3];
|
|
3575
|
+
|
|
3576
|
+
// Each sub-group is responsible for computing a fragment size of 16*8
|
|
3577
|
+
// elements of matrix D.
|
|
3578
|
+
// Each work item computes 4 elements of matrix D by gathering
|
|
3579
|
+
// their corresponding row & col matrix fragments of length k (8)
|
|
3580
|
+
// from A & B matrices respectively using below mapping logic:
|
|
3581
|
+
// row0 = (lane >> 2) & row1 = (lane >> 2) + 8
|
|
3582
|
+
// col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
|
|
3583
|
+
// As each row & col fragment of A & B matrices is distributed across
|
|
3584
|
+
// 4 work items, each iteration of below loop loads a partial fragment
|
|
3585
|
+
// of matrix A (row) and matrix B (col) using the row & col offsets.
|
|
3586
|
+
for (int i = 0; i < 4; i++) {
|
|
3587
|
+
typename MMAType<ABType>::PackType recv_a[4], recv_b[4];
|
|
3588
|
+
|
|
3589
|
+
// Load partial fragment from row0 of matrix A ({a0, a1})
|
|
3590
|
+
recv_a[0] =
|
|
3591
|
+
dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
|
|
3592
|
+
// Load partial fragment from row0 of matrix A ({a2, a3})
|
|
3593
|
+
recv_a[1] =
|
|
3594
|
+
dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
|
|
3595
|
+
// Load partial fragment from row1 of matrix A ({a0, a1})
|
|
3596
|
+
recv_a[2] =
|
|
3597
|
+
dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
|
|
3598
|
+
// Load partial fragment from row1 of matrix A ({a2, a3})
|
|
3599
|
+
recv_a[3] =
|
|
3600
|
+
dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
|
|
3601
|
+
|
|
3602
|
+
// Load partial fragment from col0 of matrix B ({b0, b1})
|
|
3603
|
+
recv_b[0] =
|
|
3604
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
|
|
3605
|
+
// Load partial fragment from col0 of matrix B ({b2, b3})
|
|
3606
|
+
recv_b[1] =
|
|
3607
|
+
dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
|
|
3608
|
+
// Load partial fragment from col1 of matrix B ({b0, b1})
|
|
3609
|
+
recv_b[2] =
|
|
3610
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + 4 + i);
|
|
3611
|
+
// Load partial fragment from col1 of matrix B ({b2, b3})
|
|
3612
|
+
recv_b[3] =
|
|
3613
|
+
dpct::select_from_sub_group(sg, b[1], col_load_offset + 4 + i);
|
|
3614
|
+
|
|
3615
|
+
auto ra = reinterpret_cast<ABType*>(recv_a);
|
|
3616
|
+
auto rb = reinterpret_cast<ABType*>(recv_b);
|
|
3617
|
+
|
|
3618
|
+
// Each work item calculates a partial product of A & B matrix
|
|
3619
|
+
// fragments and adds it to the corresponding D matrix fragment d0
|
|
3620
|
+
// += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
|
|
3621
|
+
// a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a0, a1, a2,
|
|
3622
|
+
// a3 } * col0{ b0, b1, b2, b3 } d3 += row1{ a0, a1, a2, a3 } *
|
|
3623
|
+
// col1{ b0, b1, b2, b3 }
|
|
3624
|
+
for (int j = 0; j < 4; j++) {
|
|
3625
|
+
*d[0] += static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]);
|
|
3626
|
+
*d[1] +=
|
|
3627
|
+
static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j + 4]);
|
|
3628
|
+
*d[2] +=
|
|
3629
|
+
static_cast<CDType>(ra[j + 4]) * static_cast<CDType>(rb[j]);
|
|
3630
|
+
*d[3] += static_cast<CDType>(ra[j + 4]) *
|
|
3631
|
+
static_cast<CDType>(rb[j + 4]);
|
|
3632
|
+
}
|
|
3633
|
+
}
|
|
3634
|
+
} else if constexpr (std::is_integral_v<ABType>) {
|
|
3635
|
+
// Init D matrix with fragments of C matrix
|
|
3636
|
+
*d[0] = c[0];
|
|
3637
|
+
*d[1] = c[1];
|
|
3638
|
+
*d[2] = c[2];
|
|
3639
|
+
*d[3] = c[3];
|
|
3640
|
+
|
|
3641
|
+
// Each sub-group is responsible for computing a fragment size of 16*8
|
|
3642
|
+
// elements of matrix D.
|
|
3643
|
+
// Each work item computes 4 elements of matrix D by gathering
|
|
3644
|
+
// their corresponding row & col matrix fragments of length k (8)
|
|
3645
|
+
// from A & B matrices respectively using below mapping logic:
|
|
3646
|
+
// row0 = (lane >> 2) & row1 = (lane >> 2) + 8
|
|
3647
|
+
// col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
|
|
3648
|
+
// As each row & col fragment of A & B matrices is distributed across
|
|
3649
|
+
// 4 work items, each iteration of below loop loads a partial fragment
|
|
3650
|
+
// of matrix A (row) and matrix B (col) using the row & col offsets.
|
|
3651
|
+
for (int i = 0; i < 4; i++) {
|
|
3652
|
+
typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
|
|
3653
|
+
|
|
3654
|
+
// Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
|
|
3655
|
+
recv_a[0] =
|
|
3656
|
+
dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
|
|
3657
|
+
// Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})
|
|
3658
|
+
recv_a[1] =
|
|
3659
|
+
dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
|
|
3660
|
+
// Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
|
|
3661
|
+
recv_b[0] =
|
|
3662
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
|
|
3663
|
+
// Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})
|
|
3664
|
+
recv_b[1] =
|
|
3665
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
|
|
3666
|
+
|
|
3667
|
+
auto ra = reinterpret_cast<ABType*>(recv_a);
|
|
3668
|
+
auto rb = reinterpret_cast<ABType*>(recv_b);
|
|
3669
|
+
|
|
3670
|
+
// Each work item calculates a partial product of A & B matrix
|
|
3671
|
+
// fragments and adds it to the corresponding D matrix fragment d0
|
|
3672
|
+
// += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
|
|
3673
|
+
// a0, a1, a2, a3 } * col1{ b4, b5, b6, b7 } d2 += row1{ a4, a5, a6,
|
|
3674
|
+
// a7 } * col0{ b0, b1, b2, b3 } d3 += row1{ a4, a5, a6, a7 } *
|
|
3675
|
+
// col1{ b4, b5, b6, b7 }
|
|
3676
|
+
for (int i = 0; i < 4; i++) {
|
|
3677
|
+
*d[0] += ra[i] * rb[i];
|
|
3678
|
+
*d[1] += ra[i] * rb[i + 4];
|
|
3679
|
+
*d[2] += ra[i + 4] * rb[i];
|
|
3680
|
+
*d[3] += ra[i + 4] * rb[i + 4];
|
|
3681
|
+
}
|
|
3682
|
+
}
|
|
3683
|
+
}
|
|
3684
|
+
} else if constexpr (M == 16 && N == 8 && K == 32) {
|
|
3685
|
+
if constexpr (std::is_integral_v<ABType>) {
|
|
3686
|
+
// Init D matrix with fragments of C matrix
|
|
3687
|
+
*d[0] = c[0];
|
|
3688
|
+
*d[1] = c[1];
|
|
3689
|
+
*d[2] = c[2];
|
|
3690
|
+
*d[3] = c[3];
|
|
3691
|
+
|
|
3692
|
+
// Each sub-group is responsible for computing a fragment size of 16*8
|
|
3693
|
+
// elements of matrix D.
|
|
3694
|
+
// Each work item computes 4 elements of matrix D by gathering
|
|
3695
|
+
// their corresponding row & col matrix fragments of length k (32)
|
|
3696
|
+
// from A & B matrices respectively using below mapping logic:
|
|
3697
|
+
// row0 = (lane >> 2) & row1 = (lane >> 2) + 8
|
|
3698
|
+
// col0 = ((lane % 4) * 4) + (i & 0x3) & col1 = ((lane % 4) * 4) + (i
|
|
3699
|
+
// & 0x3) As each row & col fragment of A & B matrices is distributed
|
|
3700
|
+
// across 4 work items, each iteration of below loop loads a partial
|
|
3701
|
+
// fragment of matrix A (row) and matrix B (col) using the row & col
|
|
3702
|
+
// offsets.
|
|
3703
|
+
for (int i = 0; i < 4; i++) {
|
|
3704
|
+
typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
|
|
3705
|
+
|
|
3706
|
+
// Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
|
|
3707
|
+
recv_a[0] =
|
|
3708
|
+
dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
|
|
3709
|
+
// Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})
|
|
3710
|
+
recv_a[1] =
|
|
3711
|
+
dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
|
|
3712
|
+
// Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
|
|
3713
|
+
recv_b[0] =
|
|
3714
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
|
|
3715
|
+
// Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})
|
|
3716
|
+
recv_b[1] =
|
|
3717
|
+
dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
|
|
3718
|
+
|
|
3719
|
+
auto a = reinterpret_cast<ABType*>(recv_a);
|
|
3720
|
+
auto b = reinterpret_cast<ABType*>(recv_b);
|
|
3721
|
+
|
|
3722
|
+
// Each work item calculates a partial product of A & B matrix
|
|
3723
|
+
// fragments and adds it to the corresponding D matrix fragment d0
|
|
3724
|
+
// += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
|
|
3725
|
+
// a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a4, a5, a6,
|
|
3726
|
+
// a7 } * col0{ b0, b1, b2, b3 } d3 += row1{ a4, a5, a6, a7 } *
|
|
3727
|
+
// col1{ b0, b1, b2, b3 }
|
|
3728
|
+
for (int j = 0; j < 4; j++) {
|
|
3729
|
+
*d[0] += a[j] * b[j];
|
|
3730
|
+
*d[1] += a[j] * b[j + 4];
|
|
3731
|
+
*d[2] += a[j + 4] * b[j];
|
|
3732
|
+
*d[3] += a[j + 4] * b[j + 4];
|
|
3733
|
+
}
|
|
3734
|
+
}
|
|
3735
|
+
|
|
3736
|
+
for (int i = 0; i < 4; i++) {
|
|
3737
|
+
typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
|
|
3738
|
+
|
|
3739
|
+
// Load partial fragment from row0 of matrix A ({a8, a9, a10, a11})
|
|
3740
|
+
recv_a[0] =
|
|
3741
|
+
dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
|
|
3742
|
+
// Load partial fragment from row1 of matrix A ({a12, a13, a14,
|
|
3743
|
+
// a15})
|
|
3744
|
+
recv_a[1] =
|
|
3745
|
+
dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
|
|
3746
|
+
// Load partial fragment from col0 of matrix B ({b4, b5, b6, b7})
|
|
3747
|
+
recv_b[0] =
|
|
3748
|
+
dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
|
|
3749
|
+
// Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})
|
|
3750
|
+
recv_b[1] =
|
|
3751
|
+
dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 4);
|
|
3752
|
+
|
|
3753
|
+
auto a = reinterpret_cast<ABType*>(recv_a);
|
|
3754
|
+
auto b = reinterpret_cast<ABType*>(recv_b);
|
|
3755
|
+
|
|
3756
|
+
// Each work item calculates a partial product of A & B matrix
|
|
3757
|
+
// fragments and adds it to the corresponding D matrix fragment d0
|
|
3758
|
+
// += row0{ a8, a9, a10, a11 } * col0{ b4, b5, b6, b7 } d1 += row0{
|
|
3759
|
+
// a8, a9, a10, a11 } * col1{ b4, b5, b6, b7 } d2 += row1{ a12, a13,
|
|
3760
|
+
// a14, a15 } * col0{ b4, b5, b6, b7 } d3 += row1{ a12, a13, a14,
|
|
3761
|
+
// a15 } * col1{ b4, b5, b6, b7 }
|
|
3762
|
+
for (int j = 0; j < 4; j++) {
|
|
3763
|
+
*d[0] += a[j] * b[j];
|
|
3764
|
+
*d[1] += a[j] * b[j + 4];
|
|
3765
|
+
*d[2] += a[j + 4] * b[j];
|
|
3766
|
+
*d[3] += a[j + 4] * b[j + 4];
|
|
3767
|
+
}
|
|
3768
|
+
}
|
|
3769
|
+
}
|
|
3770
|
+
}
|
|
3771
|
+
}
|
|
3028
3772
|
} // COPY from DPCT head files
|
|
3029
3773
|
|
|
3030
3774
|
#endif // GGML_SYCL_DPCT_HELPER_HPP
|