whispercpp 1.3.4 → 1.3.5
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +47 -23
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/cli/cli.cpp +121 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +10 -11
- data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
- data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
- data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
- data/ext/sources/examples/talk-llama/llama-context.h +57 -9
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
- data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
- data/ext/sources/examples/talk-llama/llama-model.h +44 -3
- data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
- data/ext/sources/examples/talk-llama/llama.cpp +729 -2
- data/ext/sources/examples/talk-llama/llama.h +152 -14
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +82 -54
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +4 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +190 -12
- data/ext/sources/ggml/src/CMakeLists.txt +82 -11
- data/ext/sources/ggml/src/ggml-alloc.c +124 -41
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
- data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
- data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
- data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
- data/ext/sources/ggml/src/ggml-impl.h +67 -6
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +425 -33
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +101 -35
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +119 -2
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +7 -0
- data/whispercpp.gemspec +1 -1
- metadata +287 -34
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -151,72 +151,50 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
|
|
|
151
151
|
}
|
|
152
152
|
|
|
153
153
|
template<typename T>
|
|
154
|
-
static
|
|
155
|
-
|
|
156
|
-
dst[i] = op_sgn(x[i]);
|
|
157
|
-
}
|
|
154
|
+
static __dpct_inline__ T op_floor(T x) {
|
|
155
|
+
return sycl::floor(x);
|
|
158
156
|
}
|
|
159
157
|
|
|
160
158
|
template<typename T>
|
|
161
|
-
static
|
|
162
|
-
|
|
163
|
-
dst[i] = op_abs(x[i]);
|
|
164
|
-
}
|
|
159
|
+
static __dpct_inline__ T op_ceil(T x) {
|
|
160
|
+
return sycl::ceil(x);
|
|
165
161
|
}
|
|
166
162
|
|
|
167
163
|
template<typename T>
|
|
168
|
-
static
|
|
169
|
-
|
|
170
|
-
dst[i] = op_elu(x[i]);
|
|
171
|
-
}
|
|
164
|
+
static __dpct_inline__ T op_round(T x) {
|
|
165
|
+
return sycl::round(x);
|
|
172
166
|
}
|
|
173
167
|
|
|
174
168
|
template<typename T>
|
|
175
|
-
static
|
|
169
|
+
static __dpct_inline__ T op_trunc(T x) {
|
|
170
|
+
return sycl::trunc(x);
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
template<typename T, typename F>
|
|
174
|
+
static void unary_op_generic_kernel(
|
|
175
|
+
const T * x,
|
|
176
|
+
T * dst,
|
|
177
|
+
const int k,
|
|
178
|
+
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3,
|
|
179
|
+
const size_t nb0, const size_t nb1, const size_t nb2, const size_t nb3,
|
|
180
|
+
const size_t nbd0, const size_t nbd1, const size_t nbd2, const size_t nbd3,
|
|
181
|
+
const sycl::nd_item<1> & item_ct1,
|
|
182
|
+
F func) {
|
|
183
|
+
|
|
184
|
+
(void) ne3;
|
|
176
185
|
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
186
|
+
const int64_t i0 = i % ne0;
|
|
187
|
+
const int64_t i1 = (i / ne0) % ne1;
|
|
188
|
+
const int64_t i2 = (i / (ne0*ne1)) % ne2;
|
|
189
|
+
const int64_t i3 = i / (ne0*ne1*ne2);
|
|
180
190
|
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
|
184
|
-
dst[i] = op_silu(x[i]);
|
|
185
|
-
}
|
|
186
|
-
}
|
|
191
|
+
const char * src_base = (const char *) x;
|
|
192
|
+
char * dst_base = (char *) dst;
|
|
187
193
|
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
|
191
|
-
dst[i] = op_gelu_quick(x[i]);
|
|
192
|
-
}
|
|
193
|
-
}
|
|
194
|
+
const T * srcp = (const T *)(src_base + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3 );
|
|
195
|
+
T * dstp = (T *)(dst_base + i0*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3);
|
|
194
196
|
|
|
195
|
-
|
|
196
|
-
static void unary_op_gelu_erf_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
|
197
|
-
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
|
198
|
-
dst[i] = op_gelu_erf(x[i]);
|
|
199
|
-
}
|
|
200
|
-
}
|
|
201
|
-
|
|
202
|
-
template<typename T>
|
|
203
|
-
static void unary_op_tanh_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
|
204
|
-
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
|
205
|
-
dst[i] = op_tanh(x[i]);
|
|
206
|
-
}
|
|
207
|
-
}
|
|
208
|
-
|
|
209
|
-
template<typename T>
|
|
210
|
-
static void unary_op_relu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
|
211
|
-
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
|
212
|
-
dst[i] = op_relu(x[i]);
|
|
213
|
-
}
|
|
214
|
-
}
|
|
215
|
-
|
|
216
|
-
template<typename T>
|
|
217
|
-
static void unary_op_sigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
|
218
|
-
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
|
219
|
-
dst[i] = op_sigmoid(x[i]);
|
|
197
|
+
*dstp = func(*srcp);
|
|
220
198
|
}
|
|
221
199
|
}
|
|
222
200
|
|
|
@@ -242,65 +220,59 @@ static void unary_op_cos_kernel(const T * x, T * dst, const int k, const sycl::n
|
|
|
242
220
|
}
|
|
243
221
|
|
|
244
222
|
template<typename T>
|
|
245
|
-
static void
|
|
223
|
+
static void unary_op_log_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
|
246
224
|
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
|
247
|
-
dst[i] =
|
|
225
|
+
dst[i] = op_log(x[i]);
|
|
248
226
|
}
|
|
249
227
|
}
|
|
250
228
|
|
|
251
|
-
template<typename T>
|
|
252
|
-
static void unary_op_hardswish_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
|
253
|
-
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
|
254
|
-
dst[i] = op_hardswish(x[i]);
|
|
255
|
-
}
|
|
256
|
-
}
|
|
257
229
|
|
|
258
230
|
template<typename T>
|
|
259
|
-
static void
|
|
231
|
+
static void unary_op_leaky_relu_kernel(const T * x, T * dst, const int k, float negative_slope, const sycl::nd_item<1> &item_ct1) {
|
|
260
232
|
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
|
261
|
-
dst[i] =
|
|
233
|
+
dst[i] = op_leaky_relu(x[i], negative_slope);
|
|
262
234
|
}
|
|
263
235
|
}
|
|
264
236
|
|
|
265
237
|
template<typename T>
|
|
266
|
-
static void
|
|
238
|
+
static void unary_op_sqr_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
|
267
239
|
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
|
268
|
-
dst[i] =
|
|
240
|
+
dst[i] = op_sqr(x[i]);
|
|
269
241
|
}
|
|
270
242
|
}
|
|
271
243
|
|
|
272
244
|
template<typename T>
|
|
273
|
-
static void
|
|
245
|
+
static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1, float min_val, float max_val) {
|
|
274
246
|
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
|
275
|
-
dst[i] =
|
|
247
|
+
dst[i] = op_clamp(x[i], min_val, max_val);
|
|
276
248
|
}
|
|
277
249
|
}
|
|
278
250
|
|
|
279
251
|
template<typename T>
|
|
280
|
-
static void
|
|
252
|
+
static void unary_op_floor_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
|
281
253
|
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
|
282
|
-
dst[i] =
|
|
254
|
+
dst[i] = op_floor(x[i]);
|
|
283
255
|
}
|
|
284
256
|
}
|
|
285
257
|
|
|
286
258
|
template<typename T>
|
|
287
|
-
static void
|
|
259
|
+
static void unary_op_ceil_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
|
288
260
|
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
|
289
|
-
dst[i] =
|
|
261
|
+
dst[i] = op_ceil(x[i]);
|
|
290
262
|
}
|
|
291
263
|
}
|
|
292
264
|
|
|
293
265
|
template<typename T>
|
|
294
|
-
static void
|
|
266
|
+
static void unary_op_round_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
|
295
267
|
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
|
296
|
-
dst[i] =
|
|
268
|
+
dst[i] = op_round(x[i]);
|
|
297
269
|
}
|
|
298
270
|
}
|
|
299
271
|
|
|
300
272
|
template<typename T>
|
|
301
|
-
static void
|
|
273
|
+
static void unary_op_trunc_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
|
302
274
|
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
|
303
|
-
dst[i] =
|
|
275
|
+
dst[i] = op_trunc(x[i]);
|
|
304
276
|
}
|
|
305
277
|
}
|
|
306
278
|
|
|
@@ -328,26 +300,6 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01,
|
|
|
328
300
|
dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
|
|
329
301
|
}
|
|
330
302
|
|
|
331
|
-
template <typename T>
|
|
332
|
-
static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02,
|
|
333
|
-
const sycl::nd_item<3> &item_ct1) {
|
|
334
|
-
int nidx = SYCL_LOCAL_ID_CALC(item_ct1, 2);
|
|
335
|
-
if (nidx >= ne0) {
|
|
336
|
-
return;
|
|
337
|
-
}
|
|
338
|
-
|
|
339
|
-
// operation
|
|
340
|
-
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
|
|
341
|
-
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
|
342
|
-
if (nidx < ne00 && item_ct1.get_group(1) < (size_t) ne01 && item_ct1.get_group(0) < (size_t) ne02) {
|
|
343
|
-
int offset_src = nidx + item_ct1.get_group(1) * ne00 +
|
|
344
|
-
item_ct1.get_group(0) * ne00 * ne01;
|
|
345
|
-
dst[offset_dst] = x[offset_src];
|
|
346
|
-
} else {
|
|
347
|
-
dst[offset_dst] = static_cast<T>(0.0f);
|
|
348
|
-
}
|
|
349
|
-
}
|
|
350
|
-
|
|
351
303
|
template<typename T>
|
|
352
304
|
static void clamp(const T * x, T * dst, const float min, const float max, const int k,
|
|
353
305
|
const sycl::nd_item<1> &item_ct1) {
|
|
@@ -417,6 +369,14 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
|
|
|
417
369
|
});
|
|
418
370
|
}
|
|
419
371
|
|
|
372
|
+
template<typename T>
|
|
373
|
+
static void arange_kernel(T * dst, const int k, T start, T step,
|
|
374
|
+
const sycl::nd_item<1> &item_ct1) {
|
|
375
|
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
|
376
|
+
dst[i] = start + static_cast<T>(i) * step;
|
|
377
|
+
}
|
|
378
|
+
}
|
|
379
|
+
|
|
420
380
|
template<typename T>
|
|
421
381
|
static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
|
|
422
382
|
const int nb02, const int nb03, const int ne10, const int ne11,
|
|
@@ -431,18 +391,6 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
|
|
|
431
391
|
});
|
|
432
392
|
}
|
|
433
393
|
|
|
434
|
-
template<typename T>
|
|
435
|
-
static void pad_sycl(const T *x, T *dst, const int ne00,
|
|
436
|
-
const int ne01, const int ne02, const int ne0,
|
|
437
|
-
const int ne1, const int ne2, queue_ptr stream) {
|
|
438
|
-
int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE);
|
|
439
|
-
sycl::range<3> gridDim(ne2, ne1, num_blocks);
|
|
440
|
-
stream->parallel_for(
|
|
441
|
-
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
|
|
442
|
-
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
|
|
443
|
-
[=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
|
|
444
|
-
}
|
|
445
|
-
|
|
446
394
|
template<typename KernelInvoker, typename... Args>
|
|
447
395
|
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
|
448
396
|
#if defined (GGML_SYCL_F16)
|
|
@@ -596,199 +544,142 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
|
|
|
596
544
|
}
|
|
597
545
|
}
|
|
598
546
|
|
|
599
|
-
template<typename
|
|
600
|
-
static inline void
|
|
601
|
-
|
|
602
|
-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
|
603
|
-
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
604
|
-
#else
|
|
605
|
-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
|
606
|
-
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
607
|
-
#endif
|
|
608
|
-
GGML_ASSERT(dst->src[0]->type == dst->type);
|
|
609
|
-
GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
|
|
610
|
-
dpct::queue_ptr main_stream = ctx.stream();
|
|
611
|
-
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
|
612
|
-
switch (dst->type) {
|
|
613
|
-
#if defined (GGML_SYCL_F16)
|
|
614
|
-
case GGML_TYPE_F16:
|
|
615
|
-
{
|
|
616
|
-
auto data_pts = cast_data<sycl::half>(dst);
|
|
617
|
-
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
|
|
618
|
-
(int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
|
|
619
|
-
break;
|
|
620
|
-
}
|
|
621
|
-
#endif
|
|
622
|
-
case GGML_TYPE_F32:
|
|
623
|
-
{
|
|
624
|
-
auto data_pts = cast_data<float>(dst);
|
|
625
|
-
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
|
|
626
|
-
(int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
|
|
627
|
-
break;
|
|
628
|
-
}
|
|
629
|
-
default:
|
|
630
|
-
GGML_ABORT("GGML tensor type not supported!\n");
|
|
631
|
-
}
|
|
632
|
-
}
|
|
547
|
+
template<typename F>
|
|
548
|
+
static inline void ggml_sycl_op_unary(
|
|
549
|
+
ggml_backend_sycl_context & ctx, ggml_tensor * dst, F func) {
|
|
633
550
|
|
|
634
|
-
|
|
551
|
+
ggml_tensor * src0 = dst->src[0];
|
|
635
552
|
|
|
553
|
+
const int64_t ne0 = dst->ne[0];
|
|
554
|
+
const int64_t ne1 = dst->ne[1];
|
|
555
|
+
const int64_t ne2 = dst->ne[2];
|
|
556
|
+
const int64_t ne3 = dst->ne[3];
|
|
636
557
|
|
|
558
|
+
const size_t nb0 = src0->nb[0];
|
|
559
|
+
const size_t nb1 = src0->nb[1];
|
|
560
|
+
const size_t nb2 = src0->nb[2];
|
|
561
|
+
const size_t nb3 = src0->nb[3];
|
|
562
|
+
|
|
563
|
+
const size_t nbd0 = dst->nb[0];
|
|
564
|
+
const size_t nbd1 = dst->nb[1];
|
|
565
|
+
const size_t nbd2 = dst->nb[2];
|
|
566
|
+
const size_t nbd3 = dst->nb[3];
|
|
637
567
|
|
|
638
|
-
static inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
639
568
|
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
|
640
|
-
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
|
569
|
+
[=](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
|
570
|
+
|
|
641
571
|
const int num_blocks = ceil_div(k_elements, 256);
|
|
572
|
+
|
|
642
573
|
stream->parallel_for(
|
|
643
574
|
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
|
644
575
|
sycl::range<1>(256)),
|
|
645
576
|
[=](sycl::nd_item<1> item_ct1) {
|
|
646
|
-
|
|
577
|
+
unary_op_generic_kernel(
|
|
578
|
+
src, dst_ptr, k_elements,
|
|
579
|
+
ne0, ne1, ne2, ne3,
|
|
580
|
+
nb0, nb1, nb2, nb3,
|
|
581
|
+
nbd0, nbd1, nbd2, nbd3,
|
|
582
|
+
item_ct1,
|
|
583
|
+
func
|
|
584
|
+
);
|
|
647
585
|
});
|
|
648
586
|
});
|
|
649
587
|
}
|
|
650
588
|
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
589
|
+
|
|
590
|
+
static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
591
|
+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
592
|
+
float start, stop, step;
|
|
593
|
+
memcpy(&start, dst->op_params, sizeof(float));
|
|
594
|
+
memcpy(&stop, (float *) dst->op_params + 1, sizeof(float));
|
|
595
|
+
memcpy(&step, (float *) dst->op_params + 2, sizeof(float));
|
|
596
|
+
dpct::queue_ptr stream = ctx.stream();
|
|
597
|
+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
|
598
|
+
float * dst_ptr = (float *)dst->data;
|
|
599
|
+
const int k = (int)ggml_nelements(dst);
|
|
600
|
+
const int num_blocks = ceil_div(k, SYCL_ARANGE_BLOCK_SIZE);
|
|
601
|
+
stream->parallel_for(
|
|
602
|
+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE),
|
|
603
|
+
sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)),
|
|
604
|
+
[=](sycl::nd_item<1> item_ct1) {
|
|
605
|
+
arange_kernel(dst_ptr, k, start, step, item_ct1);
|
|
661
606
|
});
|
|
662
607
|
}
|
|
663
608
|
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
609
|
+
} // namespace ggml_sycl_detail
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
|
|
613
|
+
static inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
614
|
+
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
|
615
|
+
return op_sgn(x);
|
|
616
|
+
});
|
|
617
|
+
}
|
|
618
|
+
|
|
619
|
+
|
|
620
|
+
static inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
621
|
+
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
|
622
|
+
return op_abs(x);
|
|
623
|
+
});
|
|
675
624
|
}
|
|
676
625
|
|
|
626
|
+
static inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
627
|
+
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
|
628
|
+
return op_elu(x);
|
|
629
|
+
});
|
|
630
|
+
}
|
|
677
631
|
static inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
678
|
-
ggml_sycl_detail::
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
stream->parallel_for(
|
|
682
|
-
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SILU_BLOCK_SIZE),
|
|
683
|
-
sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
|
|
684
|
-
[=](sycl::nd_item<1> item_ct1) {
|
|
685
|
-
unary_op_silu_kernel(src, dst_ptr, k_elements, item_ct1);
|
|
686
|
-
});
|
|
687
|
-
});
|
|
632
|
+
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
|
633
|
+
return op_silu(x);
|
|
634
|
+
});
|
|
688
635
|
}
|
|
689
636
|
|
|
690
637
|
static inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
691
|
-
ggml_sycl_detail::
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
stream->parallel_for(
|
|
695
|
-
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
|
|
696
|
-
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
|
697
|
-
[=](sycl::nd_item<1> item_ct1) {
|
|
698
|
-
unary_op_gelu_kernel(src, dst_ptr, k_elements, item_ct1);
|
|
699
|
-
});
|
|
700
|
-
});
|
|
638
|
+
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
|
639
|
+
return op_gelu(x);
|
|
640
|
+
});
|
|
701
641
|
}
|
|
702
642
|
|
|
703
|
-
static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
|
704
|
-
ggml_sycl_detail::
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
stream->parallel_for(
|
|
708
|
-
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
|
|
709
|
-
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
|
710
|
-
[=](sycl::nd_item<1> item_ct1) {
|
|
711
|
-
unary_op_gelu_quick_kernel(src, dst_ptr, k_elements, item_ct1);
|
|
712
|
-
});
|
|
713
|
-
});
|
|
643
|
+
static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
644
|
+
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
|
645
|
+
return op_gelu_quick(x);
|
|
646
|
+
});
|
|
714
647
|
}
|
|
715
648
|
|
|
716
|
-
static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
|
717
|
-
ggml_sycl_detail::
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
stream->parallel_for(
|
|
721
|
-
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
|
|
722
|
-
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
|
723
|
-
[=](sycl::nd_item<1> item_ct1) {
|
|
724
|
-
unary_op_gelu_erf_kernel(src, dst_ptr, k_elements, item_ct1);
|
|
725
|
-
});
|
|
726
|
-
});
|
|
649
|
+
static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
650
|
+
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
|
651
|
+
return op_gelu_erf(x);
|
|
652
|
+
});
|
|
727
653
|
}
|
|
728
654
|
|
|
729
655
|
static inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
730
|
-
ggml_sycl_detail::
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
stream->parallel_for(
|
|
734
|
-
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_TANH_BLOCK_SIZE),
|
|
735
|
-
sycl::range<1>(SYCL_TANH_BLOCK_SIZE)),
|
|
736
|
-
[=](sycl::nd_item<1> item_ct1) {
|
|
737
|
-
unary_op_tanh_kernel(src, dst_ptr, k_elements, item_ct1);
|
|
738
|
-
});
|
|
739
|
-
});
|
|
656
|
+
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
|
657
|
+
return op_tanh(x);
|
|
658
|
+
});
|
|
740
659
|
}
|
|
741
660
|
|
|
742
661
|
static inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
743
|
-
ggml_sycl_detail::
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
stream->parallel_for(
|
|
747
|
-
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
|
|
748
|
-
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
|
|
749
|
-
[=](sycl::nd_item<1> item_ct1) {
|
|
750
|
-
unary_op_relu_kernel(src, dst_ptr, k_elements, item_ct1);
|
|
751
|
-
});
|
|
752
|
-
});
|
|
662
|
+
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
|
663
|
+
return op_relu(x);
|
|
664
|
+
});
|
|
753
665
|
}
|
|
754
666
|
|
|
755
667
|
static inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
756
|
-
ggml_sycl_detail::
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
stream->parallel_for(
|
|
760
|
-
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE),
|
|
761
|
-
sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE)),
|
|
762
|
-
[=](sycl::nd_item<1> item_ct1) {
|
|
763
|
-
unary_op_hardsigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
|
|
764
|
-
});
|
|
765
|
-
});
|
|
668
|
+
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
|
669
|
+
return op_hardsigmoid(x);
|
|
670
|
+
});
|
|
766
671
|
}
|
|
767
672
|
|
|
768
673
|
static inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
769
|
-
ggml_sycl_detail::
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
stream->parallel_for(
|
|
773
|
-
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE),
|
|
774
|
-
sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE)),
|
|
775
|
-
[=](sycl::nd_item<1> item_ct1) {
|
|
776
|
-
unary_op_hardswish_kernel(src, dst_ptr, k_elements, item_ct1);
|
|
777
|
-
});
|
|
778
|
-
});
|
|
674
|
+
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
|
675
|
+
return op_hardswish(x);
|
|
676
|
+
});
|
|
779
677
|
}
|
|
780
678
|
|
|
781
679
|
static inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
782
|
-
ggml_sycl_detail::
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
stream->parallel_for(
|
|
786
|
-
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
|
|
787
|
-
sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
|
|
788
|
-
[=](sycl::nd_item<1> item_ct1) {
|
|
789
|
-
unary_op_exp_kernel(src, dst_ptr, k_elements, item_ct1);
|
|
790
|
-
});
|
|
791
|
-
});
|
|
680
|
+
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
|
681
|
+
return op_exp(x);
|
|
682
|
+
});
|
|
792
683
|
}
|
|
793
684
|
|
|
794
685
|
static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
@@ -805,42 +696,22 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor
|
|
|
805
696
|
}
|
|
806
697
|
|
|
807
698
|
static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
808
|
-
ggml_sycl_detail::
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
stream->parallel_for(
|
|
812
|
-
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
|
|
813
|
-
sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
|
|
814
|
-
[=](sycl::nd_item<1> item_ct1) {
|
|
815
|
-
unary_op_neg_kernel(src, dst_ptr, k_elements, item_ct1);
|
|
816
|
-
});
|
|
817
|
-
});
|
|
699
|
+
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
|
700
|
+
return op_neg(x);
|
|
701
|
+
});
|
|
818
702
|
}
|
|
819
703
|
|
|
704
|
+
|
|
820
705
|
static inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
821
|
-
ggml_sycl_detail::
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
stream->parallel_for(
|
|
825
|
-
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
|
|
826
|
-
sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
|
|
827
|
-
[=](sycl::nd_item<1> item_ct1) {
|
|
828
|
-
unary_op_step_kernel(src, dst_ptr, k_elements, item_ct1);
|
|
829
|
-
});
|
|
830
|
-
});
|
|
706
|
+
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
|
707
|
+
return op_step(x);
|
|
708
|
+
});
|
|
831
709
|
}
|
|
832
710
|
|
|
833
711
|
static inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
834
|
-
ggml_sycl_detail::
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
stream->parallel_for(
|
|
838
|
-
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE),
|
|
839
|
-
sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE)),
|
|
840
|
-
[=](sycl::nd_item<1> item_ct1) {
|
|
841
|
-
unary_op_sigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
|
|
842
|
-
});
|
|
843
|
-
});
|
|
712
|
+
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
|
713
|
+
return op_sigmoid(x);
|
|
714
|
+
});
|
|
844
715
|
}
|
|
845
716
|
|
|
846
717
|
static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
@@ -919,14 +790,6 @@ static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_te
|
|
|
919
790
|
});
|
|
920
791
|
}
|
|
921
792
|
|
|
922
|
-
static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
923
|
-
ggml_sycl_detail::dispatch_ggml_sycl_op_pad(ctx, dst,
|
|
924
|
-
[](const auto* src, auto* dst_ptr, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2,
|
|
925
|
-
queue_ptr stream) {
|
|
926
|
-
ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream);
|
|
927
|
-
});
|
|
928
|
-
}
|
|
929
|
-
|
|
930
793
|
static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
931
794
|
float min_val;
|
|
932
795
|
float max_val;
|
|
@@ -944,6 +807,58 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens
|
|
|
944
807
|
}, min_val, max_val);
|
|
945
808
|
}
|
|
946
809
|
|
|
810
|
+
static inline void ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
811
|
+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
|
812
|
+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
|
813
|
+
const int num_blocks = ceil_div(k_elements, 256);
|
|
814
|
+
stream->parallel_for(
|
|
815
|
+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
|
816
|
+
sycl::range<1>(256)),
|
|
817
|
+
[=](sycl::nd_item<1> item_ct1) {
|
|
818
|
+
unary_op_floor_kernel(src, dst_ptr, k_elements, item_ct1);
|
|
819
|
+
});
|
|
820
|
+
});
|
|
821
|
+
}
|
|
822
|
+
|
|
823
|
+
static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
824
|
+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
|
825
|
+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
|
826
|
+
const int num_blocks = ceil_div(k_elements, 256);
|
|
827
|
+
stream->parallel_for(
|
|
828
|
+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
|
829
|
+
sycl::range<1>(256)),
|
|
830
|
+
[=](sycl::nd_item<1> item_ct1) {
|
|
831
|
+
unary_op_ceil_kernel(src, dst_ptr, k_elements, item_ct1);
|
|
832
|
+
});
|
|
833
|
+
});
|
|
834
|
+
}
|
|
835
|
+
|
|
836
|
+
static inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
837
|
+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
|
838
|
+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
|
839
|
+
const int num_blocks = ceil_div(k_elements, 256);
|
|
840
|
+
stream->parallel_for(
|
|
841
|
+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
|
842
|
+
sycl::range<1>(256)),
|
|
843
|
+
[=](sycl::nd_item<1> item_ct1) {
|
|
844
|
+
unary_op_round_kernel(src, dst_ptr, k_elements, item_ct1);
|
|
845
|
+
});
|
|
846
|
+
});
|
|
847
|
+
}
|
|
848
|
+
|
|
849
|
+
static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
850
|
+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
|
851
|
+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
|
852
|
+
const int num_blocks = ceil_div(k_elements, 256);
|
|
853
|
+
stream->parallel_for(
|
|
854
|
+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
|
855
|
+
sycl::range<1>(256)),
|
|
856
|
+
[=](sycl::nd_item<1> item_ct1) {
|
|
857
|
+
unary_op_trunc_kernel(src, dst_ptr, k_elements, item_ct1);
|
|
858
|
+
});
|
|
859
|
+
});
|
|
860
|
+
}
|
|
861
|
+
|
|
947
862
|
static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
|
948
863
|
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
|
949
864
|
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
|
|
@@ -996,6 +911,98 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten
|
|
|
996
911
|
});
|
|
997
912
|
}
|
|
998
913
|
|
|
914
|
+
__dpct_inline__ float ggml_sycl_op_swiglu_oai_single(float x, float g, float alpha = 1.702f, float limit = 7.0f) {
|
|
915
|
+
x = sycl::fmin(x, limit);
|
|
916
|
+
g = sycl::fmax(sycl::fmin(g, limit), -limit);
|
|
917
|
+
|
|
918
|
+
float out_glu = x / (1.0f + sycl::native::exp(-x * alpha));
|
|
919
|
+
out_glu = out_glu * (1.0f + g);
|
|
920
|
+
return out_glu;
|
|
921
|
+
}
|
|
922
|
+
|
|
923
|
+
|
|
924
|
+
template <typename T>
|
|
925
|
+
static void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k,
|
|
926
|
+
const int64_t n, const int64_t o0, const int64_t o1,
|
|
927
|
+
float alpha, float limit, sycl::nd_item<3> item_ct1) {
|
|
928
|
+
const int64_t i = int64_t(item_ct1.get_local_range(2)) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
|
|
929
|
+
|
|
930
|
+
if (i >= k) {
|
|
931
|
+
return;
|
|
932
|
+
}
|
|
933
|
+
|
|
934
|
+
const int64_t j0 = (i / n) * o0 + (i % n);
|
|
935
|
+
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
|
|
936
|
+
|
|
937
|
+
float xi = x[j0];
|
|
938
|
+
float gi = g[j1];
|
|
939
|
+
|
|
940
|
+
dst[i] = ggml_sycl_op_swiglu_oai_single(xi, gi, alpha, limit);
|
|
941
|
+
}
|
|
942
|
+
|
|
943
|
+
template <typename T>
|
|
944
|
+
static void swiglu_oai_sycl(const T * x,
|
|
945
|
+
const T * g,
|
|
946
|
+
T * dst,
|
|
947
|
+
const int64_t k,
|
|
948
|
+
const int64_t n,
|
|
949
|
+
const int64_t o0,
|
|
950
|
+
const int64_t o1,
|
|
951
|
+
const float alpha,
|
|
952
|
+
const float limit,
|
|
953
|
+
dpct::queue_ptr stream) {
|
|
954
|
+
const int64_t num_blocks = (k + SYCL_GLU_BLOCK_SIZE - 1) / SYCL_GLU_BLOCK_SIZE;
|
|
955
|
+
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE),
|
|
956
|
+
sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE)),
|
|
957
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
958
|
+
swiglu_oai_kernel(x, g, dst, k, n, o0, o1, alpha, limit, item_ct1);
|
|
959
|
+
});
|
|
960
|
+
}
|
|
961
|
+
|
|
962
|
+
void ggml_sycl_op_swiglu_oai(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
963
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
964
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
965
|
+
void * src0_d = src0->data;
|
|
966
|
+
void * src1_d = src1 ? src1->data : src0->data;
|
|
967
|
+
const int64_t src0_o = src0->nb[1];
|
|
968
|
+
const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
|
969
|
+
void * dst_d = dst->data;
|
|
970
|
+
const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
|
971
|
+
dpct::queue_ptr stream = ctx.stream();
|
|
972
|
+
|
|
973
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
974
|
+
GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
|
|
975
|
+
GGML_ASSERT(ggml_is_contiguous(dst));
|
|
976
|
+
|
|
977
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
978
|
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
979
|
+
GGML_ASSERT(src0->type == dst->type);
|
|
980
|
+
GGML_ASSERT(dst->ne[0] == nc);
|
|
981
|
+
GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
|
|
982
|
+
|
|
983
|
+
if (src1) {
|
|
984
|
+
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
|
985
|
+
GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
|
|
986
|
+
GGML_ASSERT(src1->ne[0] == nc);
|
|
987
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
988
|
+
}
|
|
989
|
+
|
|
990
|
+
//const int32_t swapped = ((const int32_t *) dst->op_params)[1];
|
|
991
|
+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
|
992
|
+
const float alpha = ggml_get_op_params_f32(dst, 2);
|
|
993
|
+
const float limit = ggml_get_op_params_f32(dst, 3);
|
|
994
|
+
|
|
995
|
+
float * src0_p = (float *) src0_d;
|
|
996
|
+
float * src1_p = (float *) src1_d;
|
|
997
|
+
|
|
998
|
+
if (!src1) {
|
|
999
|
+
src0_p += swapped ? nc : 0;
|
|
1000
|
+
src1_p += swapped ? 0 : nc;
|
|
1001
|
+
}
|
|
1002
|
+
|
|
1003
|
+
swiglu_oai_sycl(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
|
|
1004
|
+
}
|
|
1005
|
+
|
|
999
1006
|
static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
1000
1007
|
ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
|
|
1001
1008
|
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
|
@@ -1119,10 +1126,6 @@ void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
|
1119
1126
|
ggml_sycl_op_upscale(ctx, dst);
|
|
1120
1127
|
}
|
|
1121
1128
|
|
|
1122
|
-
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
1123
|
-
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
|
1124
|
-
ggml_sycl_op_pad(ctx, dst);
|
|
1125
|
-
}
|
|
1126
1129
|
|
|
1127
1130
|
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
1128
1131
|
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
|
@@ -1159,6 +1162,11 @@ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
|
1159
1162
|
ggml_sycl_op_swiglu(ctx, dst);
|
|
1160
1163
|
}
|
|
1161
1164
|
|
|
1165
|
+
void ggml_sycl_swiglu_oai(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
1166
|
+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
|
1167
|
+
ggml_sycl_op_swiglu_oai(ctx, dst);
|
|
1168
|
+
}
|
|
1169
|
+
|
|
1162
1170
|
void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
1163
1171
|
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
|
1164
1172
|
ggml_sycl_op_geglu_erf(ctx, dst);
|
|
@@ -1168,3 +1176,28 @@ void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
|
1168
1176
|
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
|
1169
1177
|
ggml_sycl_op_geglu_quick(ctx, dst);
|
|
1170
1178
|
}
|
|
1179
|
+
|
|
1180
|
+
void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
1181
|
+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/0);
|
|
1182
|
+
ggml_sycl_detail::ggml_sycl_op_arange(ctx, dst);
|
|
1183
|
+
}
|
|
1184
|
+
|
|
1185
|
+
void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
1186
|
+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
|
1187
|
+
ggml_sycl_op_floor(ctx, dst);
|
|
1188
|
+
}
|
|
1189
|
+
|
|
1190
|
+
void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
1191
|
+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
|
1192
|
+
ggml_sycl_op_ceil(ctx, dst);
|
|
1193
|
+
}
|
|
1194
|
+
|
|
1195
|
+
void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
1196
|
+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
|
1197
|
+
ggml_sycl_op_round(ctx, dst);
|
|
1198
|
+
}
|
|
1199
|
+
|
|
1200
|
+
void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
1201
|
+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
|
1202
|
+
ggml_sycl_op_trunc(ctx, dst);
|
|
1203
|
+
}
|