whispercpp 1.3.4 → 1.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +47 -23
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/cli/cli.cpp +121 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +10 -11
- data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
- data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
- data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
- data/ext/sources/examples/talk-llama/llama-context.h +57 -9
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
- data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
- data/ext/sources/examples/talk-llama/llama-model.h +44 -3
- data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
- data/ext/sources/examples/talk-llama/llama.cpp +729 -2
- data/ext/sources/examples/talk-llama/llama.h +152 -14
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +82 -54
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +4 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +190 -12
- data/ext/sources/ggml/src/CMakeLists.txt +82 -11
- data/ext/sources/ggml/src/ggml-alloc.c +124 -41
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
- data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
- data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
- data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
- data/ext/sources/ggml/src/ggml-impl.h +67 -6
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +425 -33
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +101 -35
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +119 -2
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +7 -0
- data/whispercpp.gemspec +1 -1
- metadata +287 -34
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -595,6 +595,25 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
|
|
|
595
595
|
}
|
|
596
596
|
}
|
|
597
597
|
|
|
598
|
+
static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
|
|
599
|
+
dpct::queue_ptr stream) {
|
|
600
|
+
GGML_ASSERT(ncols % QK_MXFP4 == 0);
|
|
601
|
+
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
602
|
+
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
603
|
+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
604
|
+
|
|
605
|
+
{
|
|
606
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
607
|
+
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
608
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
609
|
+
mul_mat_vec_q<QK_MXFP4, QI_MXFP4, block_mxfp4, VDR_MXFP4_Q8_1_MMVQ, vec_dot_mxfp4_q8_1>(
|
|
610
|
+
vx, vy, dst, ncols, nrows, item_ct1);
|
|
611
|
+
});
|
|
612
|
+
});
|
|
613
|
+
}
|
|
614
|
+
}
|
|
615
|
+
|
|
616
|
+
|
|
598
617
|
static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
|
599
618
|
float *dst, const int ncols,
|
|
600
619
|
const int nrows,
|
|
@@ -1123,6 +1142,9 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
|
|
|
1123
1142
|
case GGML_TYPE_IQ4_XS:
|
|
1124
1143
|
mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1125
1144
|
break;
|
|
1145
|
+
case GGML_TYPE_MXFP4:
|
|
1146
|
+
mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1147
|
+
break;
|
|
1126
1148
|
default:
|
|
1127
1149
|
GGML_ABORT("fatal error");
|
|
1128
1150
|
}
|
|
@@ -480,6 +480,162 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
|
480
480
|
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
|
|
481
481
|
}
|
|
482
482
|
|
|
483
|
+
// Backward pass of RMS norm on the SYCL backend.
//
// dst->src[0] = dz (incoming gradient), dst->src[1] = x (forward input); dst
// receives dx. All three are F32 with unit element stride (asserted below);
// rows may be strided. eps is read from op_params[0] and falls back to 1e-5
// when it is not a positive finite number.
//
// Per row of length D:
//   sum_xx  = sum(x*x),  sum_xdz = sum(x*dz)
//   inv_r   = 1 / sqrt(sum_xx / D + eps)
//   dx      = (dz - x * sum_xdz / (sum_xx + D*eps)) * inv_r
//
// One work-group processes one row. Per-thread partial sums use Kahan
// compensated summation unless GGML_SYCL_RMS_BACK_FAST is defined.
void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);

    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); // dz
    GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32); // x
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    float eps = 1e-5f;
    std::memcpy(&eps, dst->op_params, sizeof(float));
    if (!(eps > 0.0f) || !std::isfinite(eps)) eps = 1e-5f;

    const float * g_base  = static_cast<const float *>(dst->src[0]->data); // dz
    const float * x_base  = static_cast<const float *>(dst->src[1]->data); // x
    float *       dx_base = static_cast<float *>(dst->data);               // dx

    const int64_t D  = dst->ne[0];
    const int64_t n1 = dst->ne[1];
    const int64_t n2 = dst->ne[2];
    const int64_t N  = ggml_nrows(dst);
    if (D == 0 || N == 0) return;

    const ggml_tensor * G  = dst->src[0];
    const ggml_tensor * X  = dst->src[1];
    const int           ts = (int) ggml_type_size(X->type);
    // Elements within a row must be adjacent; only the row strides may vary.
    GGML_ASSERT((size_t) X->nb[0] == (size_t) ts);
    GGML_ASSERT((size_t) G->nb[0] == (size_t) ts);
    GGML_ASSERT((size_t) dst->nb[0] == (size_t) ts);

    // Row strides expressed in elements.
    const int64_t xs1 = X->nb[1] / ts, xs2 = X->nb[2] / ts, xs3 = X->nb[3] / ts;
    const int64_t gs1 = G->nb[1] / ts, gs2 = G->nb[2] / ts, gs3 = G->nb[3] / ts;
    const int64_t ds1 = dst->nb[1] / ts, ds2 = dst->nb[2] / ts, ds3 = dst->nb[3] / ts;

    dpct::queue_ptr q = ctx.stream();

    // work-group size: multiple of WARP_SIZE, capped by device and 256, and not larger than D
    const int device_max_wg = ggml_sycl_info().max_work_group_sizes[ctx.device];
    auto      roundup       = [](int v, int m) { return ((v + m - 1) / m) * m; };
    int       wg_cap        = 256;
    if (device_max_wg > 0) wg_cap = std::min(wg_cap, device_max_wg);
    const int WG = std::max(WARP_SIZE, std::min(roundup((int) std::min<int64_t>(D, wg_cap), WARP_SIZE), wg_cap));

    q->submit([&](sycl::handler & cgh) {
        const int nwarps_loc = std::max(1, WG / WARP_SIZE);
        // one (sum, compensation) pair per sub-group for the cross-sub-group reduction
        auto l_xx = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(nwarps_loc), cgh);
        auto l_xg = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(nwarps_loc), cgh);

        cgh.parallel_for(
            sycl::nd_range<3>(sycl::range<3>(1, 1, N) * sycl::range<3>(1, 1, WG), sycl::range<3>(1, 1, WG)),
            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                // 64-bit row id: N is int64_t and may exceed INT_MAX.
                const int64_t row = item_ct1.get_group(2);
                const int     tid = item_ct1.get_local_id(2);

                const int64_t i1 = row % n1;
                const int64_t i2 = (row / n1) % n2;
                const int64_t i3 = row / (n1 * n2);

                const float * __restrict x_row = x_base + i3 * xs3 + i2 * xs2 + i1 * xs1;
                const float * __restrict g_row = g_base + i3 * gs3 + i2 * gs2 + i1 * gs1;
                float * __restrict       d_row = dx_base + i3 * ds3 + i2 * ds2 + i1 * ds1;

                // per-thread accumulation (compensated by default)
                float sum_xx = 0.f, sum_xg = 0.f;
#ifndef GGML_SYCL_RMS_BACK_FAST
                // Kahan compensation: c = (amount actually added) - (amount intended),
                // so the exact partial sum is sum - c.
                float c_xx = 0.f, c_xg = 0.f;
#endif
                for (int64_t col = tid; col < D; col += WG) {
                    const float xv = x_row[col];
                    const float gv = g_row[col];
#ifdef GGML_SYCL_RMS_BACK_FAST
                    sum_xx += xv * xv;
                    sum_xg += xv * gv;
#else
                    float y1 = xv * xv - c_xx;
                    float t1 = sum_xx + y1;
                    c_xx     = (t1 - sum_xx) - y1;
                    sum_xx   = t1;

                    float y2 = xv * gv - c_xg;
                    float t2 = sum_xg + y2;
                    c_xg     = (t2 - sum_xg) - y2;
                    sum_xg   = t2;
#endif
                }

                // Pack (partial sum, compensation) so both go through the same reductions.
#ifndef GGML_SYCL_RMS_BACK_FAST
                sycl::float2 xx(sum_xx, c_xx);
                sycl::float2 xg(sum_xg, c_xg);
#else
                sycl::float2 xx(sum_xx, 0.f);
                sycl::float2 xg(sum_xg, 0.f);
#endif
                xx = warp_reduce_sum(xx, item_ct1);
                xg = warp_reduce_sum(xg, item_ct1);

                // cross-warp reduction using local memory (single barrier)
                const auto sub_group = item_ct1.get_sub_group();
                const auto sg_id     = sub_group.get_group_linear_id();
                const auto wi_in_sg  = sub_group.get_local_linear_id();
                const int  nthreads  = item_ct1.get_local_range(2);
                const int  nwarps    = nthreads / WARP_SIZE;

                sycl::float2 xx_total = xx;
                sycl::float2 xg_total = xg;
                if (nwarps > 1) {
                    if (wi_in_sg == 0) {
                        l_xx[sg_id] = xx;
                        l_xg[sg_id] = xg;
                    }
                    item_ct1.barrier(sycl::access::fence_space::local_space);

                    // Sub-group 0 folds the per-sub-group partials; other sub-groups keep
                    // their local totals, which are never read (only tid 0's value is used).
                    if (sg_id == 0) {
                        const unsigned wi_u = wi_in_sg;
                        sycl::float2 xx_first = (wi_u < static_cast<unsigned>(nwarps)) ? l_xx[wi_u] : sycl::float2(0.f, 0.f);
                        sycl::float2 xg_first = (wi_u < static_cast<unsigned>(nwarps)) ? l_xg[wi_u] : sycl::float2(0.f, 0.f);
                        xx_total = warp_reduce_sum(xx_first, item_ct1);
                        xg_total = warp_reduce_sum(xg_first, item_ct1);
                    }
                }

                // compute inv_r and coeff once per row and broadcast to the whole work-group
                float inv_r = 0.f;
                float coeff = 0.f;
                if (tid == 0) {
                    // Exact sum = running sum - accumulated compensation (Kahan). Adding the
                    // compensation instead would double the residual error rather than cancel it.
                    const float sum_xx_f  = xx_total.x() - xx_total.y();
                    const float sum_xdz_f = xg_total.x() - xg_total.y();
                    const float mean_eps  = sum_xx_f / (float) D + eps;
                    const float sum_eps   = sum_xx_f + eps * (float) D;
                    inv_r = sycl::rsqrt(mean_eps);
                    coeff = -sum_xdz_f / sum_eps;
                }
                inv_r = sycl::group_broadcast(item_ct1.get_group(), inv_r);
                coeff = sycl::group_broadcast(item_ct1.get_group(), coeff);

                for (int64_t col = tid; col < D; col += WG) {
                    d_row[col] = (g_row[col] + coeff * x_row[col]) * inv_r;
                }
            });
    });
}
|
|
638
|
+
|
|
483
639
|
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
|
484
640
|
|
|
485
641
|
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
|
@@ -19,6 +19,8 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
|
|
19
19
|
|
|
20
20
|
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
|
21
21
|
|
|
22
|
+
void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
|
23
|
+
|
|
22
24
|
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
|
23
25
|
|
|
24
26
|
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
//
|
|
2
|
+
// MIT license
|
|
3
|
+
// Copyright (C) 2025 Intel Corporation
|
|
4
|
+
// SPDX-License-Identifier: MIT
|
|
5
|
+
//
|
|
6
|
+
|
|
7
|
+
//
|
|
8
|
+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
9
|
+
// See https://llvm.org/LICENSE.txt for license information.
|
|
10
|
+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
11
|
+
//
|
|
12
|
+
|
|
13
|
+
//#include "common.hpp"
|
|
14
|
+
#include "pad.hpp"
|
|
15
|
+
|
|
16
|
+
static void pad_f32(const float * src, float * dst,
|
|
17
|
+
const int lp0, const int rp0, const int lp1, const int rp1,
|
|
18
|
+
const int lp2, const int rp2, const int lp3, const int rp3,
|
|
19
|
+
const int ne0, const int ne1, const int ne2, const int ne3,
|
|
20
|
+
sycl::nd_item<3> item_ct1) {
|
|
21
|
+
int i0 = item_ct1.get_local_id(2) +
|
|
22
|
+
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
|
23
|
+
int i1 = item_ct1.get_group(1);
|
|
24
|
+
int i2 = item_ct1.get_group(0) % ne2;
|
|
25
|
+
int i3 = item_ct1.get_group(0) / ne2;
|
|
26
|
+
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
|
27
|
+
return;
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
// operation
|
|
31
|
+
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
|
32
|
+
if ((i0 >= lp0 && i0 < ne0 - rp0) &&
|
|
33
|
+
(i1 >= lp1 && i1 < ne1 - rp1) &&
|
|
34
|
+
(i2 >= lp2 && i2 < ne2 - rp2) &&
|
|
35
|
+
(i3 >= lp3 && i3 < ne3 - rp3)) {
|
|
36
|
+
const int64_t i00 = i0 - lp0;
|
|
37
|
+
const int64_t i01 = i1 - lp1;
|
|
38
|
+
const int64_t i02 = i2 - lp2;
|
|
39
|
+
const int64_t i03 = i3 - lp3;
|
|
40
|
+
const int64_t ne02 = ne2 - lp2 - rp2;
|
|
41
|
+
const int64_t ne01 = ne1 - lp1 - rp1;
|
|
42
|
+
const int64_t ne00 = ne0 - lp0 - rp0;
|
|
43
|
+
|
|
44
|
+
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) +
|
|
45
|
+
i02 * (ne00 * ne01) + i01 * ne00 + i00;
|
|
46
|
+
|
|
47
|
+
dst[dst_idx] = src[src_idx];
|
|
48
|
+
} else {
|
|
49
|
+
dst[dst_idx] = 0.0f;
|
|
50
|
+
}
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
static void pad_f32_sycl(const float *src, float *dst, const int lp0,
|
|
54
|
+
const int rp0, const int lp1, const int rp1,
|
|
55
|
+
const int lp2, const int rp2, const int lp3,
|
|
56
|
+
const int rp3, const int ne0, const int ne1,
|
|
57
|
+
const int ne2, const int ne3,
|
|
58
|
+
dpct::queue_ptr stream) {
|
|
59
|
+
int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
|
|
60
|
+
dpct::dim3 gridDim(num_blocks, ne1, ne2 * ne3);
|
|
61
|
+
stream->parallel_for(
|
|
62
|
+
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
|
|
63
|
+
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
|
|
64
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
65
|
+
pad_f32(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1,
|
|
66
|
+
ne2, ne3, item_ct1);
|
|
67
|
+
});
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
// GGML_OP_PAD for a contiguous F32 src0 on the SYCL backend.
// dst->op_params holds eight int32 values (lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3):
// the left/right padding applied to each of the four dimensions.
void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0   = dst->src[0];
    const float *       in     = (const float *) src0->data;
    float *             out    = (float *) dst->data;
    dpct::queue_ptr     stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    // pad_f32 reads src as a dense array, so strided inputs are not supported.
    GGML_ASSERT(ggml_is_contiguous(src0));

    const int32_t * pads = (const int32_t *) dst->op_params;

    pad_f32_sycl(in, out,
                 pads[0], pads[1], pads[2], pads[3],
                 pads[4], pads[5], pads[6], pads[7],
                 dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
}
|
|
93
|
+
|
|
94
|
+
// Entry point for GGML_OP_PAD: holds a scope_op_debug_print guard (one source
// tensor) for the duration of the call and delegates to ggml_sycl_op_pad.
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print dbg_scope(__func__, dst, /*num_src=*/1);
    ggml_sycl_op_pad(ctx, dst);
}
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
//
|
|
2
|
+
// MIT license
|
|
3
|
+
// Copyright (C) 2025 Intel Corporation
|
|
4
|
+
// SPDX-License-Identifier: MIT
|
|
5
|
+
//
|
|
6
|
+
|
|
7
|
+
//
|
|
8
|
+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
9
|
+
// See https://llvm.org/LICENSE.txt for license information.
|
|
10
|
+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
11
|
+
//
|
|
12
|
+
|
|
13
|
+
#ifndef GGML_SYCL_PAD_HPP
#define GGML_SYCL_PAD_HPP

#include "common.hpp"

// Work-items per work-group (along ne0) used when launching the pad kernel.
#define SYCL_PAD_BLOCK_SIZE 256

// GGML_OP_PAD implementation: pads the contiguous F32 tensor dst->src[0] into dst
// using the per-dimension left/right padding stored in dst->op_params.
void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

// GGML_OP_PAD entry point with scoped debug tracing; forwards to ggml_sycl_op_pad.
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

#endif // GGML_SYCL_PAD_HPP
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
#include "pad_reflect_1d.hpp"
|
|
2
|
+
|
|
3
|
+
static void pad_reflect_1d_kernel_f32(
|
|
4
|
+
const void *__restrict__ src0, void *__restrict__ dst, const int64_t ne0,
|
|
5
|
+
const int64_t ne00, const sycl::uint3 ne01, const int64_t ne02,
|
|
6
|
+
const int64_t ne03, const int64_t nb00, const int64_t nb01,
|
|
7
|
+
const int64_t nb02, const int64_t nb03, const int64_t nb0,
|
|
8
|
+
const int64_t nb1, const int64_t nb2, const int64_t nb3, const int p0,
|
|
9
|
+
const int p1, sycl::nd_item<3> item_ct1) {
|
|
10
|
+
|
|
11
|
+
const int64_t i3 = item_ct1.get_group(0);
|
|
12
|
+
const int64_t i2 = item_ct1.get_group(1);
|
|
13
|
+
|
|
14
|
+
const sycl::uint2 div_mod_packed =
|
|
15
|
+
fast_div_modulo(item_ct1.get_group(2), ne01);
|
|
16
|
+
const int64_t tile1 = div_mod_packed.y();
|
|
17
|
+
const int64_t tile0 = div_mod_packed.x();
|
|
18
|
+
const int64_t i1 = tile1;
|
|
19
|
+
const int64_t i0 =
|
|
20
|
+
item_ct1.get_local_id(2) + tile0 * item_ct1.get_local_range(2);
|
|
21
|
+
|
|
22
|
+
if (i0 >= ne0 || i1 >= ne01.z() || i2 >= ne02 || i3 >= ne03) {
|
|
23
|
+
return;
|
|
24
|
+
}
|
|
25
|
+
|
|
26
|
+
const char *src0_ptr =
|
|
27
|
+
(const char *)src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
|
|
28
|
+
char *dst_ptr = (char *)dst + i3 * nb3 + i2 * nb2 + i1 * nb1;
|
|
29
|
+
|
|
30
|
+
const int64_t rel_i0 = i0 - p0; // relative i0 in src0
|
|
31
|
+
int64_t src_idx;
|
|
32
|
+
|
|
33
|
+
if (rel_i0 < 0) {
|
|
34
|
+
// Left padding - reflect
|
|
35
|
+
src_idx = -rel_i0;
|
|
36
|
+
} else if (rel_i0 < ne00) {
|
|
37
|
+
// Middle - copy
|
|
38
|
+
src_idx = rel_i0;
|
|
39
|
+
} else {
|
|
40
|
+
// Right padding - reflect
|
|
41
|
+
src_idx = 2 * ne00 - 2 - rel_i0;
|
|
42
|
+
}
|
|
43
|
+
const float value = *(const float *)(src0_ptr + src_idx * nb00);
|
|
44
|
+
*(float *)(dst_ptr + i0 * nb0) = value;
|
|
45
|
+
|
|
46
|
+
GGML_UNUSED(p1);
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
// GGML_OP_PAD_REFLECT_1D for F32 tensors on the SYCL backend.
// op_params[0] (p0) and op_params[1] (p1) are the left and right pad widths along
// dim 0; dst->ne[0] must equal src0->ne[0] + p0 + p1. Both tensors are passed to
// the kernel with their byte strides (nb[]).
// Launch shape: ceil(ne0 / SYCL_PAD_REFLECT_1D_BLOCK_SIZE) tiles per src row, with
// all (tile, row) pairs flattened into group dim 2, ne02 in dim 1 and ne03 in dim 0.
void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context &ctx,
                                 ggml_tensor *dst) {
    const ggml_tensor *src0   = dst->src[0];
    dpct::queue_ptr    stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const int32_t *pads = (const int32_t *)dst->op_params;
    const int p0 = pads[0];
    const int p1 = pads[1];

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    // precomputed divisor values consumed by fast_div_modulo inside the kernel
    const sycl::uint3 ne01_fd = init_fastdiv_values(ne01);
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];
    const int64_t ne0  = dst->ne[0];

    GGML_ASSERT(ne0 == ne00 + p0 + p1);

    constexpr int64_t tile    = SYCL_PAD_REFLECT_1D_BLOCK_SIZE;
    const int64_t     n_tiles = (ne0 + tile - 1) / tile;
    const dpct::dim3  grid((unsigned)(ne01 * n_tiles), (unsigned)ne02, (unsigned)ne03);
    const dpct::dim3  wg((unsigned)tile, 1, 1);

    // Plain values for the device lambda to capture by copy.
    void *const  src_data = src0->data;
    void *const  dst_data = dst->data;
    const size_t s_nb0 = src0->nb[0], s_nb1 = src0->nb[1], s_nb2 = src0->nb[2], s_nb3 = src0->nb[3];
    const size_t d_nb0 = dst->nb[0], d_nb1 = dst->nb[1], d_nb2 = dst->nb[2], d_nb3 = dst->nb[3];

    stream->parallel_for(sycl::nd_range<3>(grid * wg, wg), [=](sycl::nd_item<3> item_ct1) {
        pad_reflect_1d_kernel_f32(src_data, dst_data, ne0, ne00, ne01_fd, ne02, ne03,
                                  s_nb0, s_nb1, s_nb2, s_nb3,
                                  d_nb0, d_nb1, d_nb2, d_nb3,
                                  p0, p1, item_ct1);
    });
}
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
#ifndef GGML_SYCL_PAD_REFLECT_1D_HPP
#define GGML_SYCL_PAD_REFLECT_1D_HPP

#include "common.hpp"

// Work-items per work-group along dim 0 for the reflect-pad kernel; also the
// tile width used to split dst->ne[0] across work-groups.
#define SYCL_PAD_REFLECT_1D_BLOCK_SIZE 256

// GGML_OP_PAD_REFLECT_1D (F32): reflect-pads dst->src[0] along dim 0 by
// op_params[0] elements on the left and op_params[1] elements on the right.
void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst);

#endif // GGML_SYCL_PAD_REFLECT_1D_HPP
|
|
@@ -31,6 +31,7 @@
|
|
|
31
31
|
#define SYCL_SQRT_BLOCK_SIZE 256
|
|
32
32
|
#define SYCL_SIN_BLOCK_SIZE 256
|
|
33
33
|
#define SYCL_SQR_BLOCK_SIZE 256
|
|
34
|
+
#define SYCL_SET_BLOCK_SIZE 256
|
|
34
35
|
#define SYCL_CPY_BLOCK_SIZE 32
|
|
35
36
|
#define SYCL_SCALE_BLOCK_SIZE 256
|
|
36
37
|
#define SYCL_CLAMP_BLOCK_SIZE 256
|
|
@@ -49,6 +50,7 @@
|
|
|
49
50
|
#define SYCL_ARGMAX_BLOCK_SIZE 256
|
|
50
51
|
#define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
|
|
51
52
|
#define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
|
|
53
|
+
#define SYCL_ARANGE_BLOCK_SIZE 256
|
|
52
54
|
|
|
53
55
|
// dmmv = dequantize_mul_mat_vec
|
|
54
56
|
#ifndef GGML_SYCL_DMMV_X
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
#include "repeat_back.hpp"
|
|
2
|
+
|
|
3
|
+
#include "common.hpp"
|
|
4
|
+
|
|
5
|
+
#define GGML_ASSERT_TENSOR_FITS_INT(t) \
|
|
6
|
+
GGML_ASSERT((t)->ne[0] < INT_MAX && (t)->ne[1] < INT_MAX && (t)->ne[2] < INT_MAX && (t)->ne[3] < INT_MAX)
|
|
7
|
+
|
|
8
|
+
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
9
|
+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
|
10
|
+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
11
|
+
|
|
12
|
+
const float * src0_dd = (const float *) dst->src[0]->data;
|
|
13
|
+
float * dst_dd = (float *) dst->data;
|
|
14
|
+
|
|
15
|
+
GGML_ASSERT_TENSOR_FITS_INT(dst);
|
|
16
|
+
GGML_ASSERT_TENSOR_FITS_INT(dst->src[0]);
|
|
17
|
+
|
|
18
|
+
const int ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];
|
|
19
|
+
const int ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2],
|
|
20
|
+
ne03 = dst->src[0]->ne[3];
|
|
21
|
+
|
|
22
|
+
const int nr0 = ne00 / ne0;
|
|
23
|
+
const int nr1 = ne01 / ne1;
|
|
24
|
+
const int nr2 = ne02 / ne2;
|
|
25
|
+
const int nr3 = ne03 / ne3;
|
|
26
|
+
|
|
27
|
+
const int nb0 = dst->src[0]->nb[0];
|
|
28
|
+
const int nb1 = dst->src[0]->nb[1];
|
|
29
|
+
const int nb2 = dst->src[0]->nb[2];
|
|
30
|
+
const int nb3 = dst->src[0]->nb[3];
|
|
31
|
+
|
|
32
|
+
const char * base = (const char *) src0_dd;
|
|
33
|
+
|
|
34
|
+
const size_t total = (size_t) ne0 * ne1 * ne2 * ne3;
|
|
35
|
+
constexpr int BLOCK_SIZE = 256;
|
|
36
|
+
const int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
|
37
|
+
|
|
38
|
+
const float inv_ne0 = 1.0f / ne0;
|
|
39
|
+
const float inv_ne_01 = 1.0f / (ne0 * ne1);
|
|
40
|
+
const float inv_ne_012 = 1.0f / (ne0 * ne1 * ne2);
|
|
41
|
+
const int repeat_count = nr0 * nr1 * nr2 * nr3;
|
|
42
|
+
|
|
43
|
+
queue_ptr stream = ctx.stream();
|
|
44
|
+
|
|
45
|
+
stream->parallel_for(
|
|
46
|
+
sycl::nd_range<1>(sycl::range<1>(num_blocks * BLOCK_SIZE), sycl::range<1>(BLOCK_SIZE)),
|
|
47
|
+
[=](sycl::nd_item<1> item_ct1) {
|
|
48
|
+
const size_t i = item_ct1.get_global_linear_id();
|
|
49
|
+
if (i >= total) {
|
|
50
|
+
return;
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
const int i3 = (int) (i * inv_ne_012);
|
|
54
|
+
const int i2 = (int) (i * inv_ne_01) - i3 * ne2;
|
|
55
|
+
const int i1 = (int) (i * inv_ne0) - (int) (i * inv_ne_01) * ne1;
|
|
56
|
+
const int i0 = i - (int) (i * inv_ne0) * ne0;
|
|
57
|
+
|
|
58
|
+
int j0 = 0, j1 = 0, j2 = 0, j3 = 0;
|
|
59
|
+
float acc = 0.0f;
|
|
60
|
+
|
|
61
|
+
for (int j = 0; j < repeat_count; ++j) {
|
|
62
|
+
const float * ptr = (const float *) (base + (i0 + j0 * ne0) * nb0 + (i1 + j1 * ne1) * nb1 +
|
|
63
|
+
(i2 + j2 * ne2) * nb2 + (i3 + j3 * ne3) * nb3);
|
|
64
|
+
acc += *ptr;
|
|
65
|
+
|
|
66
|
+
int carry = (++j0 >= nr0);
|
|
67
|
+
j0 -= carry * nr0;
|
|
68
|
+
carry = (carry && (++j1 >= nr1));
|
|
69
|
+
j1 -= carry * nr1;
|
|
70
|
+
carry = (carry && (++j2 >= nr2));
|
|
71
|
+
j2 -= carry * nr2;
|
|
72
|
+
j3 += carry;
|
|
73
|
+
}
|
|
74
|
+
dst_dd[i] = acc;
|
|
75
|
+
});
|
|
76
|
+
}
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
#include "roll.hpp"
|
|
2
|
+
#include "common.hpp"
|
|
3
|
+
|
|
4
|
+
using namespace sycl;
|
|
5
|
+
|
|
6
|
+
// Returns i + shift, reduced by n once if the sum reaches n.
// Only a single subtraction is performed, so the result lies in [0, n) only
// when 0 <= i + shift < 2n.
static inline int wrap_add(int i, int shift, int n) {
    const int sum = i + shift;
    if (sum < n) {
        return sum;
    }
    return sum - n;
}
|
|
11
|
+
|
|
12
|
+
// Circularly shifts a dense (contiguous) F32 tensor of shape [ne0, ne1, ne2, ne3]
// by (sh0, sh1, sh2, sh3) elements along the respective dimensions, writing the
// result to dst_d:  dst[i0,i1,i2,i3] = src[(i0 - sh0) mod ne0, ..., (i3 - sh3) mod ne3].
//
// One work-item per element over a 3-D range (ne3, ne2, ne1*ne0); dims 0 and 1
// are fused into the last range dimension. The kernel is only submitted to `q`;
// this function does not wait for completion. Returns immediately if any dim is 0.
static void kernel_roll_fused_i0_i1(
    queue &q,
    const float *src_d,
    float *dst_d,
    int ne0, int ne1, int ne2, int ne3,
    int sh0, int sh1, int sh2, int sh3)
{
    if (ne0 == 0 || ne1 == 0 || ne2 == 0 || ne3 == 0) return;

    // Dense element strides in 64 bits: ne0*ne1*ne2 (and flat indices built from
    // it) can exceed INT_MAX for large tensors.
    const int64_t stride1 = ne0;
    const int64_t stride2 = stride1 * ne1;
    const int64_t stride3 = stride2 * ne2;

    // Reading src at i + (n - sh) instead of i - sh keeps every source coordinate
    // non-negative. Any int shift is accepted; the arithmetic cannot overflow.
    auto reverse_shift = [](int sh, int n) -> int {
        int r = sh % n;      // in (-n, n)
        if (r < 0) r += n;   // now in [0, n)
        return r == 0 ? 0 : n - r;
    };
    const int shNe0 = reverse_shift(sh0, ne0);
    const int shNe1 = reverse_shift(sh1, ne1);
    const int shNe2 = reverse_shift(sh2, ne2);
    const int shNe3 = reverse_shift(sh3, ne3);

    // Multiply in size_t so ne1*ne0 cannot overflow int before widening.
    const range<3> global{ (size_t) ne3, (size_t) ne2, (size_t) ne1 * (size_t) ne0 };

    q.submit([&](handler &h) {
        h.parallel_for(global, [=](id<3> idx) {
            const int i3 = (int) idx[0];
            const int i2 = (int) idx[1];

            // Un-fuse dims 0 and 1 from the last range index.
            const size_t fused = idx[2];
            const int i1 = (int) (fused / (size_t) ne0);
            const int i0 = (int) (fused % (size_t) ne0);

            const int64_t idx_dst = i0 + i1 * stride1 + i2 * stride2 + i3 * stride3;

            const int s0 = wrap_add(i0, shNe0, ne0);
            const int s1 = wrap_add(i1, shNe1, ne1);
            const int s2 = wrap_add(i2, shNe2, ne2);
            const int s3 = wrap_add(i3, shNe3, ne3);

            const int64_t idx_src = s0 + s1 * stride1 + s2 * stride2 + s3 * stride3;

            dst_d[idx_dst] = src_d[idx_src];
        });
    });
}
|
|
69
|
+
|
|
70
|
+
// SYCL implementation of GGML_OP_ROLL for F32 tensors.
//
// Circularly shifts dst->src[0] by the four int32 shifts stored in
// dst->op_params[0..3] (one per dimension) and writes the result to dst.
// Both tensors must be contiguous and have the same shape, because the copy
// and the kernel below index them as dense arrays. Work is enqueued on ctx's
// stream; exceptions from the kernel submission are logged and rethrown.
void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const ggml_tensor *src = dst->src[0];
    GGML_ASSERT(src && src->type == GGML_TYPE_F32);
    // The memcpy fast path and the kernel ignore nb[] and use dense element
    // strides, so a strided view would otherwise be read incorrectly.
    GGML_ASSERT(ggml_is_contiguous(src) && ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_are_same_shape(src, dst));

    const int ne0 = (int) dst->ne[0];
    const int ne1 = (int) dst->ne[1];
    const int ne2 = (int) dst->ne[2];
    const int ne3 = (int) dst->ne[3];

    // Reduce a shift into [0, n); a non-positive dimension gets shift 0.
    auto norm = [](int sh, int n) -> int {
        if (n <= 0) return 0;
        sh %= n;
        if (sh < 0) sh += n;
        return sh;
    };

    const int32_t *params = (const int32_t *) dst->op_params;
    const int shift0 = norm(params[0], ne0);
    const int shift1 = norm(params[1], ne1);
    const int shift2 = norm(params[2], ne2);
    const int shift3 = norm(params[3], ne3);

    // Every shift is a whole number of periods (including 0): the roll is the
    // identity, so a plain device copy suffices.
    if ((shift0 | shift1 | shift2 | shift3) == 0) {
        const size_t nb = ggml_nbytes(src);
        queue *q = ctx.stream();
        SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb)));
        return;
    }

    try {
        queue *q = ctx.stream();

        const float *src_d = (const float *) src->data;
        float *dst_d = (float *) dst->data;
        GGML_ASSERT(src_d && dst_d);

        kernel_roll_fused_i0_i1(
            *q, src_d, dst_d,
            ne0, ne1, ne2, ne3,
            shift0, shift1, shift2, shift3
        );
    } catch (const std::exception &e) {
        std::fprintf(stderr, "[SYCL-ROLL] ERROR: %s\n", e.what());
        throw;
    }
}
|