whispercpp 1.3.2 → 1.3.4
This diff reflects the changes between publicly released versions of the package as they appear in the supported public registries, and is provided for informational purposes only.
- checksums.yaml +4 -4
- data/.gitignore +6 -3
- data/README.md +71 -14
- data/Rakefile +20 -7
- data/ext/.gitignore +4 -6
- data/ext/dependencies.rb +36 -24
- data/ext/extconf.rb +1 -1
- data/ext/options.rb +48 -184
- data/ext/ruby_whisper.c +18 -0
- data/ext/ruby_whisper_context.c +43 -12
- data/ext/ruby_whisper_model.c +1 -1
- data/ext/ruby_whisper_params.c +59 -27
- data/ext/ruby_whisper_segment.c +81 -4
- data/ext/ruby_whisper_transcribe.cpp +13 -7
- data/ext/ruby_whisper_vad_params.c +1 -1
- data/ext/sources/CMakeLists.txt +5 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/build-xcframework.sh +24 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
- data/ext/sources/examples/addon.node/addon.cpp +154 -35
- data/ext/sources/examples/addon.node/index.js +10 -5
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/bench.cpp +29 -18
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +7 -4
- data/ext/sources/examples/command/command.cpp +58 -32
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/common-whisper.cpp +14 -7
- data/ext/sources/examples/lsp/lsp.cpp +21 -17
- data/ext/sources/examples/quantize/quantize.cpp +3 -0
- data/ext/sources/examples/server/CMakeLists.txt +3 -0
- data/ext/sources/examples/server/server.cpp +193 -35
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +10 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -0
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
- data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +756 -15
- data/ext/sources/examples/talk-llama/llama-arch.h +85 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +773 -272
- data/ext/sources/examples/talk-llama/llama-batch.h +126 -55
- data/ext/sources/examples/talk-llama/llama-chat.cpp +150 -13
- data/ext/sources/examples/talk-llama/llama-chat.h +8 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +814 -542
- data/ext/sources/examples/talk-llama/llama-context.h +68 -32
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
- data/ext/sources/examples/talk-llama/llama-graph.cpp +787 -440
- data/ext/sources/examples/talk-llama/llama-graph.h +333 -153
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +128 -6
- data/ext/sources/examples/talk-llama/llama-hparams.h +80 -17
- data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +137 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +1248 -1967
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +218 -345
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +164 -52
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +266 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +139 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1154 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +182 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +94 -4
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +44 -17
- data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +11377 -5248
- data/ext/sources/examples/talk-llama/llama-model.h +87 -9
- data/ext/sources/examples/talk-llama/llama-quant.cpp +137 -16
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +502 -38
- data/ext/sources/examples/talk-llama/llama-vocab.h +46 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -17
- data/ext/sources/examples/talk-llama/llama.h +176 -151
- data/ext/sources/examples/talk-llama/talk-llama.cpp +11 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +212 -0
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +6 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
- data/ext/sources/ggml/CMakeLists.txt +106 -33
- data/ext/sources/ggml/cmake/common.cmake +24 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-backend.h +18 -2
- data/ext/sources/ggml/include/ggml-cpu.h +2 -0
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml.h +365 -21
- data/ext/sources/ggml/src/CMakeLists.txt +98 -25
- data/ext/sources/ggml/src/ggml-alloc.c +265 -141
- data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +35 -13
- data/ext/sources/ggml/src/ggml-backend.cpp +266 -60
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +4 -4
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +15 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
- data/ext/sources/ggml/src/ggml-cann/common.h +149 -2
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +521 -78
- data/ext/sources/ggml/src/ggml-common.h +21 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +165 -50
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -3
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +3650 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1891 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2160 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1897 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +214 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +18 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +23 -7
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +179 -110
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +44 -33
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +228 -98
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +532 -1124
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +3374 -2081
- data/ext/sources/ggml/src/ggml-cpu/ops.h +13 -8
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1193 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +34 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1982 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +120 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +367 -46
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +3 -3
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +272 -35
- data/ext/sources/ggml/src/ggml-cpu/vec.h +794 -142
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +291 -81
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +117 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +20 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +499 -368
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +142 -93
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +90 -50
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +636 -222
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +73 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +506 -0
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +4 -5
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -98
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +23 -19
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +179 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -2
- data/ext/sources/ggml/src/ggml-impl.h +229 -175
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +21 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +163 -63
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3208 -1575
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +32 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4430 -792
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +138 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +117 -24
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +85 -62
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +13 -17
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +21 -2
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +116 -211
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +700 -1041
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +20 -9
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +17 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +2 -96
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +393 -250
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +32 -8
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -11
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +125 -21
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +4 -3
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +105 -17
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4198 -1145
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +64 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +338 -71
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml.c +802 -142
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +32 -4
- data/ext/sources/include/whisper.h +2 -0
- data/ext/sources/src/CMakeLists.txt +2 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
- data/ext/sources/src/whisper.cpp +241 -215
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +3 -3
- data/ext/sources/tests/test-vad.cpp +2 -2
- data/extsources.rb +15 -9
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +57 -2
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +75 -38
- data/{tests → test}/helper.rb +1 -12
- data/{tests → test}/test_model.rb +9 -0
- data/test/test_package.rb +51 -0
- data/{tests → test}/test_params.rb +8 -0
- data/test/test_segment.rb +146 -0
- data/{tests → test}/test_whisper.rb +70 -0
- data/whispercpp.gemspec +2 -3
- metadata +246 -191
- data/ext/sources/.dockerignore +0 -3
- data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
- data/ext/sources/ci/run.sh +0 -336
- data/ext/sources/close-issue.yml +0 -28
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -6431
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -336
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -5998
- data/tests/test_package.rb +0 -46
- data/tests/test_segment.rb +0 -74
- /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
- /data/{tests → test}/test_callback.rb +0 -0
- /data/{tests → test}/test_error.rb +0 -0
- /data/{tests → test}/test_vad.rb +0 -0
- /data/{tests → test}/test_vad_params.rb +0 -0
@@ -1,12 +1,19 @@
 #include "common.hpp"
+#include "ggml-sycl/presets.hpp"
 #include "ggml.h"
 #include "element_wise.hpp"
 
+#define SYCL_GLOBAL_ID_LOOP(K, ITEM) \
+    for (auto i = ITEM.get_global_id(0); i < (size_t)K; i += ITEM.get_global_range(0))
+
+#define SYCL_LOCAL_ID_CALC(ITEM, IDX) \
+    (ITEM.get_local_range(IDX) * ITEM.get_group(IDX) + ITEM.get_local_id(IDX))
+
+
 static void acc_f32(const float * x, const float * y, float * dst, const int ne,
     const int ne10, const int ne11, const int ne12,
-    const int nb1, const int nb2, int offset, const sycl::nd_item<
-    const int i = item_ct1
-    item_ct1.get_local_id(2);
+    const int nb1, const int nb2, int offset, const sycl::nd_item<1> &item_ct1) {
+    const int i = SYCL_LOCAL_ID_CALC(item_ct1, 0);
     if (i >= ne) {
         return;
     }
@@ -21,239 +28,280 @@ static void acc_f32(const float * x, const float * y, float * dst, const int ne,
     }
 }
 
+/* Unary OP funcs */
 template<typename T>
-static
-
-        dst[i] = x[i] > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x[i] < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
-}
+static __dpct_inline__ T op_sgn(T x) {
+    return x > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
 }
[… remainder of the hunk follows the same pattern: each inline unary kernel is rewritten as a __dpct_inline__ helper (op_abs, op_elu, op_gelu, op_silu, op_gelu_quick, op_gelu_erf, op_tanh, op_relu, op_sigmoid, op_sqrt, op_sin, op_cos, op_hardsigmoid, op_hardswish, op_exp, op_log, op_neg, op_step, op_leaky_relu, op_sqr, op_clamp) plus a unary_op_*_kernel wrapper that iterates with SYCL_GLOBAL_ID_LOOP over a 1-D sycl::nd_item<1>; the extract breaks off at unary_op_neg_kernel …]
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
275
|
+
dst[i] = op_neg(x[i]);
|
276
|
+
}
|
277
|
+
}
|
228
278
|
|
229
|
-
|
230
|
-
|
279
|
+
template<typename T>
|
280
|
+
static void unary_op_step_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
281
|
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
282
|
+
dst[i] = op_step(x[i]);
|
231
283
|
}
|
232
|
-
dst[i] = x[i] > static_cast<T>(0.0f);
|
233
284
|
}
|
234
285
|
|
235
286
|
template<typename T>
|
236
|
-
static void
|
237
|
-
|
238
|
-
|
239
|
-
item_ct1.get_local_id(2);
|
240
|
-
if (i >= k) {
|
241
|
-
return;
|
287
|
+
static void unary_op_leaky_relu_kernel(const T * x, T * dst, const int k, float negative_slope, const sycl::nd_item<1> &item_ct1) {
|
288
|
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
289
|
+
dst[i] = op_leaky_relu(x[i], negative_slope);
|
242
290
|
}
|
243
|
-
dst[i] = sycl::fmax((x[i]), static_cast<T>(0)) +
|
244
|
-
sycl::fmin((x[i]), static_cast<T>(0.0f)) * negative_slope;
|
245
291
|
}
|
246
292
|
|
247
293
|
template<typename T>
|
248
|
-
static void
|
249
|
-
|
250
|
-
|
251
|
-
|
294
|
+
static void unary_op_sqr_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
295
|
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
296
|
+
dst[i] = op_sqr(x[i]);
|
297
|
+
}
|
298
|
+
}
|
252
299
|
|
253
|
-
|
254
|
-
|
300
|
+
template<typename T>
|
301
|
+
static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1, float min_val, float max_val) {
|
302
|
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
303
|
+
dst[i] = op_clamp(x[i], min_val, max_val);
|
255
304
|
}
|
256
|
-
dst[i] = x[i] * x[i];
|
257
305
|
}
|
258
306
|
|
259
307
|
template<typename T>
|
@@ -272,10 +320,10 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01,
|
|
272
320
|
int i12 = (index / (ne10 * ne11)) % ne12;
|
273
321
|
int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
|
274
322
|
|
275
|
-
int i00 = i10 / sf0;
|
276
|
-
int i01 = i11 / sf1;
|
277
|
-
int i02 = i12 / sf2;
|
278
|
-
int i03 = i13 / sf3;
|
323
|
+
int i00 = static_cast<int>(i10 / sf0);
|
324
|
+
int i01 = static_cast<int>(i11 / sf1);
|
325
|
+
int i02 = static_cast<int>(i12 / sf2);
|
326
|
+
int i03 = static_cast<int>(i13 / sf3);
|
279
327
|
|
280
328
|
dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
|
281
329
|
}
|
@@ -283,8 +331,7 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01,
|
|
283
331
|
template <typename T>
|
284
332
|
static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02,
|
285
333
|
const sycl::nd_item<3> &item_ct1) {
|
286
|
-
int nidx = item_ct1
|
287
|
-
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
334
|
+
int nidx = SYCL_LOCAL_ID_CALC(item_ct1, 2);
|
288
335
|
if (nidx >= ne0) {
|
289
336
|
return;
|
290
337
|
}
|
@@ -301,285 +348,72 @@ static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne
|
|
301
348
|
}
|
302
349
|
}
|
303
350
|
|
304
|
-
|
305
351
|
template<typename T>
|
306
352
|
static void clamp(const T * x, T * dst, const float min, const float max, const int k,
|
307
|
-
const sycl::nd_item<
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
if (i >= k) {
|
312
|
-
return;
|
353
|
+
const sycl::nd_item<1> &item_ct1) {
|
354
|
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
355
|
+
dst[i] = x[i] < static_cast<T>(min) ? static_cast<T>(min) : (x[i] > static_cast<T>(max) ? static_cast<T>(max) : x[i]);
|
313
356
|
}
|
314
|
-
|
315
|
-
dst[i] = x[i] < static_cast<T>(min) ? static_cast<T>(min) : (x[i] > static_cast<T>(max) ? static_cast<T>(max) : x[i]);
|
316
|
-
}
|
317
|
-
|
318
|
-
static void acc_f32_sycl(const float *x, const float *y, float *dst,
|
319
|
-
const int n_elements, const int ne10, const int ne11,
|
320
|
-
const int ne12, const int nb1, const int nb2,
|
321
|
-
const int offset, queue_ptr stream) {
|
322
|
-
int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
|
323
|
-
stream->parallel_for(
|
324
|
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
325
|
-
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
|
326
|
-
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
|
327
|
-
[=](sycl::nd_item<3> item_ct1) {
|
328
|
-
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
|
329
|
-
item_ct1);
|
330
|
-
});
|
331
|
-
}
|
332
|
-
|
333
|
-
template<typename T>
|
334
|
-
static void gelu_sycl(const T *x, T *dst, const int k,
|
335
|
-
queue_ptr stream) {
|
336
|
-
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
|
337
|
-
stream->parallel_for(
|
338
|
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
339
|
-
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
|
340
|
-
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
|
341
|
-
[=](sycl::nd_item<3> item_ct1) {
|
342
|
-
gelu(x, dst, k, item_ct1);
|
343
|
-
});
|
344
|
-
}
|
345
|
-
|
346
|
-
template<typename T>
|
347
|
-
static void silu_sycl(const T *x, T *dst, const int k,
|
348
|
-
queue_ptr stream) {
|
349
|
-
const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
|
350
|
-
stream->parallel_for(
|
351
|
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
352
|
-
sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
|
353
|
-
sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
|
354
|
-
[=](sycl::nd_item<3> item_ct1) {
|
355
|
-
silu(x, dst, k, item_ct1);
|
356
|
-
});
|
357
|
-
}
|
358
|
-
|
359
|
-
template<typename T>
|
360
|
-
static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
|
361
|
-
// hard code for now
|
362
|
-
const int num_blocks = ceil_div(k, 256);
|
363
|
-
stream->parallel_for(
|
364
|
-
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
|
365
|
-
sgn(x, dst, k, item_ct1);
|
366
|
-
});
|
367
|
-
}
|
368
|
-
|
369
|
-
template<typename T>
|
370
|
-
static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
|
371
|
-
// hard code for now
|
372
|
-
const int num_blocks = ceil_div(k, 256);
|
373
|
-
stream->parallel_for(
|
374
|
-
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
|
375
|
-
abs_op(x, dst, k, item_ct1);
|
376
|
-
});
|
377
|
-
}
|
378
|
-
|
379
|
-
|
380
|
-
template<typename T>
|
381
|
-
static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
|
382
|
-
// hard code for now
|
383
|
-
const int num_blocks = ceil_div(k, 256);
|
384
|
-
stream->parallel_for(
|
385
|
-
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
|
386
|
-
elu_op(x, dst, k, item_ct1);
|
387
|
-
});
|
388
|
-
}
|
389
|
-
|
390
|
-
template<typename T>
|
391
|
-
static void gelu_quick_sycl(const T *x, T *dst, const int k,
|
392
|
-
queue_ptr stream) {
|
393
|
-
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
|
394
|
-
stream->parallel_for(
|
395
|
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
396
|
-
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
|
397
|
-
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
|
398
|
-
[=](sycl::nd_item<3> item_ct1) {
|
399
|
-
gelu_quick(x, dst, k, item_ct1);
|
400
|
-
});
|
401
|
-
}
|
402
|
-
|
403
|
-
template<typename T>
|
404
|
-
static void tanh_sycl(const T *x, T *dst, const int k,
|
405
|
-
queue_ptr stream) {
|
406
|
-
const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
|
407
|
-
stream->parallel_for(
|
408
|
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
409
|
-
sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
|
410
|
-
sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
|
411
|
-
[=](sycl::nd_item<3> item_ct1) {
|
412
|
-
tanh(x, dst, k, item_ct1);
|
413
|
-
});
|
414
|
-
}
|
415
|
-
|
416
|
-
template<typename T>
|
417
|
-
static void relu_sycl(const T *x, T *dst, const int k,
|
418
|
-
queue_ptr stream) {
|
419
|
-
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
|
420
|
-
stream->parallel_for(
|
421
|
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
422
|
-
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
|
423
|
-
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
|
424
|
-
[=](sycl::nd_item<3> item_ct1) {
|
425
|
-
relu(x, dst, k, item_ct1);
|
426
|
-
});
|
427
|
-
}
|
428
|
-
|
429
|
-
template<typename T>
|
430
|
-
static void hardsigmoid_sycl(const T *x, T *dst, const int k,
|
431
|
-
queue_ptr stream) {
|
432
|
-
const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
|
433
|
-
stream->parallel_for(
|
434
|
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
435
|
-
sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
|
436
|
-
sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
|
437
|
-
[=](sycl::nd_item<3> item_ct1) {
|
438
|
-
hardsigmoid(x, dst, k, item_ct1);
|
439
|
-
});
|
440
|
-
}
|
441
|
-
|
442
|
-
template<typename T>
|
443
|
-
static void hardswish_sycl(const T *x, T *dst, const int k,
|
444
|
-
queue_ptr stream) {
|
445
|
-
const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
|
446
|
-
stream->parallel_for(
|
447
|
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
448
|
-
sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
|
449
|
-
sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
|
450
|
-
[=](sycl::nd_item<3> item_ct1) {
|
451
|
-
hardswish(x, dst, k, item_ct1);
|
452
|
-
});
|
453
|
-
}
|
454
|
-
|
455
|
-
template<typename T>
|
456
|
-
static void exp_sycl(const T *x, T *dst, const int k,
|
457
|
-
queue_ptr stream) {
|
458
|
-
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
|
459
|
-
stream->parallel_for(
|
460
|
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
461
|
-
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
|
462
|
-
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
|
463
|
-
[=](sycl::nd_item<3> item_ct1) {
|
464
|
-
exp(x, dst, k, item_ct1);
|
465
|
-
});
|
466
|
-
}
|
467
|
-
|
468
|
-
template<typename T>
|
469
|
-
static void log_sycl(const T *x, T *dst, const int k,
|
470
|
-
queue_ptr stream) {
|
471
|
-
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
|
472
|
-
stream->parallel_for(
|
473
|
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
474
|
-
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
|
475
|
-
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
|
476
|
-
[=](sycl::nd_item<3> item_ct1) {
|
477
|
-
log(x, dst, k, item_ct1);
|
478
|
-
});
|
479
|
-
}
|
480
|
-
|
481
|
-
template<typename T>
|
482
|
-
static void neg_sycl(const T *x, T *dst, const int k,
|
483
|
-
queue_ptr stream) {
|
484
|
-
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
|
485
|
-
stream->parallel_for(
|
486
|
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
487
|
-
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
|
488
|
-
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
|
489
|
-
[=](sycl::nd_item<3> item_ct1) {
|
490
|
-
neg(x, dst, k, item_ct1);
|
491
|
-
});
|
492
|
-
}
|
493
|
-
|
494
|
-
template<typename T>
|
495
|
-
static void step_sycl(const T *x, T *dst, const int k,
|
496
|
-
queue_ptr stream) {
|
497
|
-
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
|
498
|
-
stream->parallel_for(
|
499
|
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
500
|
-
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
|
501
|
-
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
|
502
|
-
[=](sycl::nd_item<3> item_ct1) {
|
503
|
-
step(x, dst, k, item_ct1);
|
504
|
-
});
|
505
357
|
}
|
506
358
|
|
507
359
|
template<typename T>
|
508
|
-
static void
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)),
|
515
|
-
[=](sycl::nd_item<3> item_ct1) {
|
516
|
-
sigmoid(x, dst, k, item_ct1);
|
517
|
-
});
|
360
|
+
static void gated_op_fused_geglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
|
361
|
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
362
|
+
const int64_t j0 = (i / n) * o0 + (i % n);
|
363
|
+
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
|
364
|
+
dst[i] = op_gelu(x[j0]) * g[j1];
|
365
|
+
}
|
518
366
|
}
|
519
367
|
|
520
368
|
template<typename T>
|
521
|
-
static void
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)),
|
528
|
-
[=](sycl::nd_item<3> item_ct1) {
|
529
|
-
sqrt(x, dst, k, item_ct1);
|
530
|
-
});
|
369
|
+
static void gated_op_fused_reglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
|
370
|
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
371
|
+
const int64_t j0 = (i / n) * o0 + (i % n);
|
372
|
+
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
|
373
|
+
dst[i] = op_relu(x[j0]) * g[j1];
|
374
|
+
}
|
531
375
|
}
|
532
376
|
|
533
377
|
template<typename T>
|
534
|
-
static void
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
|
541
|
-
[=](sycl::nd_item<3> item_ct1) {
|
542
|
-
sin(x, dst, k, item_ct1);
|
543
|
-
});
|
378
|
+
static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
|
379
|
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
380
|
+
const int64_t j0 = (i / n) * o0 + (i % n);
|
381
|
+
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
|
382
|
+
dst[i] = op_silu(x[j0]) * g[j1];
|
383
|
+
}
|
544
384
|
}
|
545
385
|
|
546
386
|
template<typename T>
|
547
|
-
static void
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
|
554
|
-
[=](sycl::nd_item<3> item_ct1) {
|
555
|
-
cos(x, dst, k, item_ct1);
|
556
|
-
});
|
387
|
+
static void gated_op_fused_geglu_erf(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
|
388
|
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
389
|
+
const int64_t j0 = (i / n) * o0 + (i % n);
|
390
|
+
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
|
391
|
+
dst[i] = op_gelu_erf(x[j0]) * g[j1];
|
392
|
+
}
|
557
393
|
}
|
558
394
|
|
559
395
|
template<typename T>
|
560
|
-
static void
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
|
567
|
-
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
|
568
|
-
[=](sycl::nd_item<3> item_ct1) {
|
569
|
-
leaky_relu(x, dst, k, negative_slope, item_ct1);
|
570
|
-
});
|
396
|
+
static void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
|
397
|
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
398
|
+
const int64_t j0 = (i / n) * o0 + (i % n);
|
399
|
+
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
|
400
|
+
dst[i] = op_gelu_quick(x[j0]) * g[j1];
|
401
|
+
}
|
571
402
|
}
|
572
403
|
|
573
|
-
|
574
|
-
static void
|
575
|
-
|
576
|
-
|
404
|
+
namespace ggml_sycl_detail {
|
405
|
+
static void acc_f32_sycl(const float *x, const float *y, float *dst,
|
406
|
+
const int n_elements, const int ne10, const int ne11,
|
407
|
+
const int ne12, const int nb1, const int nb2,
|
408
|
+
const int offset, queue_ptr stream) {
|
409
|
+
int num_blocks = ceil_div(n_elements, SYCL_ACC_BLOCK_SIZE);
|
577
410
|
stream->parallel_for(
|
578
|
-
sycl::nd_range<
|
579
|
-
sycl::range<
|
580
|
-
sycl::range<
|
581
|
-
[=](sycl::nd_item<
|
582
|
-
|
411
|
+
sycl::nd_range<1>(sycl::range<1>(num_blocks) *
|
412
|
+
sycl::range<1>(SYCL_ACC_BLOCK_SIZE),
|
413
|
+
sycl::range<1>(SYCL_ACC_BLOCK_SIZE)),
|
414
|
+
[=](sycl::nd_item<1> item_ct1) {
|
415
|
+
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
|
416
|
+
item_ct1);
|
583
417
|
});
|
584
418
|
}
|
585
419
|
|
@@ -589,11 +423,10 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
|
|
589
423
|
const int ne12, const int ne13, const float sf0, const float sf1,
|
590
424
|
const float sf2, const float sf3, queue_ptr stream) {
|
591
425
|
int dst_size = ne10 * ne11 * ne12 * ne13;
|
592
|
-
int num_blocks = (dst_size
|
426
|
+
int num_blocks = ceil_div(dst_size, SYCL_UPSCALE_BLOCK_SIZE);
|
593
427
|
sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
|
594
428
|
stream->parallel_for(
|
595
|
-
sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
|
596
|
-
[=](sycl::nd_item<1> item_ct1) {
|
429
|
+
sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
597
430
|
upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
|
598
431
|
});
|
599
432
|
}
|
@@ -602,35 +435,19 @@ template<typename T>
|
|
602
435
|
static void pad_sycl(const T *x, T *dst, const int ne00,
|
603
436
|
const int ne01, const int ne02, const int ne0,
|
604
437
|
const int ne1, const int ne2, queue_ptr stream) {
|
605
|
-
int num_blocks = (ne0
|
438
|
+
int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE);
|
606
439
|
sycl::range<3> gridDim(ne2, ne1, num_blocks);
|
607
440
|
stream->parallel_for(
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
pad(x, dst, ne0, ne00, ne01, ne02, item_ct1);
|
612
|
-
});
|
613
|
-
}
|
614
|
-
|
615
|
-
template<typename T>
|
616
|
-
static void clamp_sycl(const T *x, T *dst, const float min,
|
617
|
-
const float max, const int k,
|
618
|
-
queue_ptr stream) {
|
619
|
-
const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
|
620
|
-
stream->parallel_for(
|
621
|
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
622
|
-
sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
|
623
|
-
sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
|
624
|
-
[=](sycl::nd_item<3> item_ct1) {
|
625
|
-
clamp(x, dst, min, max, k, item_ct1);
|
626
|
-
});
|
441
|
+
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
|
442
|
+
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
|
443
|
+
[=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
|
627
444
|
}
|
628
445
|
|
629
|
-
|
446
|
+
template<typename KernelInvoker, typename... Args>
|
447
|
+
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
630
448
|
#if defined (GGML_SYCL_F16)
|
631
449
|
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
632
450
|
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
633
|
-
|
634
451
|
#else
|
635
452
|
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
636
453
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
@@ -643,14 +460,14 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
|
643
460
|
case GGML_TYPE_F16:
|
644
461
|
{
|
645
462
|
auto data_pts = cast_data<sycl::half>(dst);
|
646
|
-
|
463
|
+
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
|
647
464
|
break;
|
648
465
|
}
|
649
466
|
#endif
|
650
467
|
case GGML_TYPE_F32:
|
651
468
|
{
|
652
469
|
auto data_pts = cast_data<float>(dst);
|
653
|
-
|
470
|
+
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
|
654
471
|
break;
|
655
472
|
}
|
656
473
|
default:
|
@@ -658,11 +475,11 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
|
658
475
|
}
|
659
476
|
}
|
660
477
|
|
661
|
-
|
478
|
+
template<typename KernelInvoker, typename... Args>
|
479
|
+
static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
662
480
|
#if defined (GGML_SYCL_F16)
|
663
481
|
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
664
482
|
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
665
|
-
|
666
483
|
#else
|
667
484
|
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
668
485
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
@@ -670,19 +487,66 @@ inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
|
670
487
|
GGML_ASSERT(dst->src[0]->type == dst->type);
|
671
488
|
dpct::queue_ptr main_stream = ctx.stream();
|
672
489
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
490
|
+
const ggml_tensor * src0 = dst->src[0];
|
491
|
+
const ggml_tensor * src1 = dst->src[1];
|
492
|
+
const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;;
|
493
|
+
GGML_ASSERT(dst->ne[0] == nc);
|
494
|
+
GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
|
495
|
+
GGML_ASSERT(ggml_is_contiguous(dst));
|
496
|
+
const int32_t swapped = ((const int32_t *) dst->op_params)[1];
|
497
|
+
void * src0_d = src0->data;
|
498
|
+
void * src1_d = src1 ? src1->data : src0->data;
|
499
|
+
const int64_t src0_o = src0->nb[1];
|
500
|
+
const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
501
|
+
void * dst_d = dst->data;
|
502
|
+
if (src1) {
|
503
|
+
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
504
|
+
GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
|
505
|
+
GGML_ASSERT(src1->ne[0] == nc);
|
506
|
+
GGML_ASSERT(src0->type == src1->type);
|
507
|
+
}
|
673
508
|
switch (dst->type) {
|
674
509
|
#if defined (GGML_SYCL_F16)
|
675
510
|
case GGML_TYPE_F16:
|
676
511
|
{
|
677
|
-
|
678
|
-
|
512
|
+
sycl::half * src0_p = (sycl::half *) src0_d;
|
513
|
+
sycl::half * src1_p = (sycl::half *) src1_d;
|
514
|
+
|
515
|
+
if (!src1) {
|
516
|
+
src0_p += swapped ? nc : 0;
|
517
|
+
src1_p += swapped ? 0 : nc;
|
518
|
+
}
|
519
|
+
kernel_invoker(src0_p,
|
520
|
+
src1_p,
|
521
|
+
(sycl::half *) dst_d,
|
522
|
+
ggml_nelements(dst),
|
523
|
+
nc,
|
524
|
+
src0_o / sizeof(sycl::half),
|
525
|
+
src1_o / sizeof(sycl::half),
|
526
|
+
main_stream,
|
527
|
+
std::forward<Args>(args)...);
|
679
528
|
break;
|
680
529
|
}
|
681
530
|
#endif
|
682
531
|
case GGML_TYPE_F32:
|
683
532
|
{
|
684
|
-
|
685
|
-
|
533
|
+
float * src0_p = (float *) src0_d;
|
534
|
+
float * src1_p = (float *) src1_d;
|
535
|
+
|
536
|
+
if (!src1) {
|
537
|
+
src0_p += swapped ? nc : 0;
|
538
|
+
src1_p += swapped ? 0 : nc;
|
539
|
+
}
|
540
|
+
|
541
|
+
kernel_invoker(src0_p,
|
542
|
+
src1_p,
|
543
|
+
(float *) dst_d,
|
544
|
+
ggml_nelements(dst),
|
545
|
+
nc,
|
546
|
+
src0_o / sizeof(float),
|
547
|
+
src1_o / sizeof(float),
|
548
|
+
main_stream,
|
549
|
+
std::forward<Args>(args)...);
|
686
550
|
break;
|
687
551
|
}
|
688
552
|
default:
|
@@ -690,32 +554,41 @@ inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
|
690
554
|
}
|
691
555
|
}
|
692
556
|
|
693
|
-
|
694
|
-
inline void
|
557
|
+
template<typename KernelInvoker, typename... Args>
|
558
|
+
static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
695
559
|
#if defined (GGML_SYCL_F16)
|
696
560
|
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
697
561
|
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
698
|
-
|
699
562
|
#else
|
700
563
|
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
701
564
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
702
565
|
#endif
|
703
566
|
GGML_ASSERT(dst->src[0]->type == dst->type);
|
567
|
+
|
704
568
|
dpct::queue_ptr main_stream = ctx.stream();
|
705
569
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
570
|
+
|
571
|
+
const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0];
|
572
|
+
const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1];
|
573
|
+
const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2];
|
574
|
+
const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3];
|
706
575
|
switch (dst->type) {
|
707
576
|
#if defined (GGML_SYCL_F16)
|
708
577
|
case GGML_TYPE_F16:
|
709
578
|
{
|
710
579
|
auto data_pts = cast_data<sycl::half>(dst);
|
711
|
-
|
580
|
+
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],
|
581
|
+
(int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,
|
582
|
+
main_stream, std::forward<Args>(args)...);
|
712
583
|
break;
|
713
584
|
}
|
714
585
|
#endif
|
715
586
|
case GGML_TYPE_F32:
|
716
587
|
{
|
717
588
|
auto data_pts = cast_data<float>(dst);
|
718
|
-
|
589
|
+
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],
|
590
|
+
(int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,
|
591
|
+
main_stream, std::forward<Args>(args)...);
|
719
592
|
break;
|
720
593
|
}
|
721
594
|
default:
|
@@ -723,7 +596,8 @@ inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
|
723
596
|
}
|
724
597
|
}
|
725
598
|
|
726
|
-
|
599
|
+
template<typename KernelInvoker, typename... Args>
|
600
|
+
static inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
727
601
|
#if defined (GGML_SYCL_F16)
|
728
602
|
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
729
603
|
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
@@ -732,6 +606,7 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
|
|
732
606
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
733
607
|
#endif
|
734
608
|
GGML_ASSERT(dst->src[0]->type == dst->type);
|
609
|
+
GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
|
735
610
|
dpct::queue_ptr main_stream = ctx.stream();
|
736
611
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
737
612
|
switch (dst->type) {
|
@@ -739,14 +614,16 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
|
|
739
614
|
case GGML_TYPE_F16:
|
740
615
|
{
|
741
616
|
auto data_pts = cast_data<sycl::half>(dst);
|
742
|
-
|
617
|
+
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
|
618
|
+
(int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
|
743
619
|
break;
|
744
620
|
}
|
745
621
|
#endif
|
746
622
|
case GGML_TYPE_F32:
|
747
623
|
{
|
748
624
|
auto data_pts = cast_data<float>(dst);
|
749
|
-
|
625
|
+
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
|
626
|
+
(int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
|
750
627
|
break;
|
751
628
|
}
|
752
629
|
default:
|
@@ -754,623 +631,320 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
|
|
754
631
|
}
|
755
632
|
}
|
756
633
|
|
757
|
-
|
758
|
-
|
759
|
-
|
760
|
-
|
761
|
-
|
762
|
-
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
|
771
|
-
|
772
|
-
auto data_pts = cast_data<sycl::half>(dst);
|
773
|
-
gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
774
|
-
break;
|
775
|
-
}
|
776
|
-
#endif
|
777
|
-
case GGML_TYPE_F32:
|
778
|
-
{
|
779
|
-
auto data_pts = cast_data<float>(dst);
|
780
|
-
gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
781
|
-
break;
|
782
|
-
}
|
783
|
-
default:
|
784
|
-
GGML_ABORT("GGML tensor type not supported!\n");
|
785
|
-
}
|
634
|
+
} // namespace ggml_sycl_detail
|
635
|
+
|
636
|
+
|
637
|
+
|
638
|
+
static inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
639
|
+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
640
|
+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
641
|
+
const int num_blocks = ceil_div(k_elements, 256);
|
642
|
+
stream->parallel_for(
|
643
|
+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
644
|
+
sycl::range<1>(256)),
|
645
|
+
[=](sycl::nd_item<1> item_ct1) {
|
646
|
+
unary_op_sgn_kernel(src, dst_ptr, k_elements, item_ct1);
|
647
|
+
});
|
648
|
+
});
|
786
649
|
}
|
787
650
|
|
788
|
-
inline void
|
789
|
-
|
790
|
-
|
791
|
-
|
792
|
-
|
793
|
-
|
794
|
-
|
795
|
-
|
796
|
-
|
797
|
-
|
798
|
-
|
799
|
-
switch (dst->type) {
|
800
|
-
#if defined (GGML_SYCL_F16)
|
801
|
-
case GGML_TYPE_F16:
|
802
|
-
{
|
803
|
-
auto data_pts = cast_data<sycl::half>(dst);
|
804
|
-
gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
805
|
-
break;
|
806
|
-
}
|
807
|
-
#endif
|
808
|
-
case GGML_TYPE_F32:
|
809
|
-
{
|
810
|
-
auto data_pts = cast_data<float>(dst);
|
811
|
-
gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
812
|
-
break;
|
813
|
-
}
|
814
|
-
default:
|
815
|
-
GGML_ABORT("GGML tensor type not supported!\n");
|
816
|
-
}
|
817
|
-
}
|
818
|
-
|
819
|
-
inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
820
|
-
#if defined (GGML_SYCL_F16)
|
821
|
-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
822
|
-
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
823
|
-
#else
|
824
|
-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
825
|
-
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
826
|
-
#endif
|
827
|
-
GGML_ASSERT(dst->src[0]->type == dst->type);
|
828
|
-
dpct::queue_ptr main_stream = ctx.stream();
|
829
|
-
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
830
|
-
switch (dst->type) {
|
831
|
-
#if defined (GGML_SYCL_F16)
|
832
|
-
case GGML_TYPE_F16:
|
833
|
-
{
|
834
|
-
auto data_pts = cast_data<sycl::half>(dst);
|
835
|
-
tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
836
|
-
break;
|
837
|
-
}
|
838
|
-
#endif
|
839
|
-
case GGML_TYPE_F32:
|
840
|
-
{
|
841
|
-
auto data_pts = cast_data<float>(dst);
|
842
|
-
tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
843
|
-
break;
|
844
|
-
}
|
845
|
-
default:
|
846
|
-
GGML_ABORT("GGML tensor type not supported!\n");
|
847
|
-
}
|
848
|
-
}
|
849
|
-
|
850
|
-
inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
851
|
-
#if defined (GGML_SYCL_F16)
|
852
|
-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
853
|
-
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
854
|
-
#else
|
855
|
-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
856
|
-
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
857
|
-
#endif
|
858
|
-
GGML_ASSERT(dst->src[0]->type == dst->type);
|
859
|
-
dpct::queue_ptr main_stream = ctx.stream();
|
860
|
-
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
861
|
-
|
862
|
-
switch (dst->type) {
|
863
|
-
#if defined (GGML_SYCL_F16)
|
864
|
-
case GGML_TYPE_F16:
|
865
|
-
{
|
866
|
-
auto data_pts = cast_data<sycl::half>(dst);
|
867
|
-
relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
868
|
-
break;
|
869
|
-
}
|
870
|
-
#endif
|
871
|
-
case GGML_TYPE_F32:
|
872
|
-
{
|
873
|
-
auto data_pts = cast_data<float>(dst);
|
874
|
-
relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
875
|
-
break;
|
876
|
-
}
|
877
|
-
default:
|
878
|
-
GGML_ABORT("GGML tensor type not supported!\n");
|
879
|
-
}
|
880
|
-
}
|
881
|
-
|
882
|
-
inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
883
|
-
#if defined (GGML_SYCL_F16)
|
884
|
-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
885
|
-
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
886
|
-
#else
|
887
|
-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
888
|
-
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
889
|
-
#endif
|
890
|
-
GGML_ASSERT(dst->src[0]->type == dst->type);
|
891
|
-
|
892
|
-
dpct::queue_ptr main_stream = ctx.stream();
|
893
|
-
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
894
|
-
|
895
|
-
switch (dst->type) {
|
896
|
-
#if defined (GGML_SYCL_F16)
|
897
|
-
case GGML_TYPE_F16:
|
898
|
-
{
|
899
|
-
auto data_pts = cast_data<sycl::half>(dst);
|
900
|
-
hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
901
|
-
break;
|
902
|
-
}
|
903
|
-
#endif
|
904
|
-
case GGML_TYPE_F32:
|
905
|
-
{
|
906
|
-
auto data_pts = cast_data<float>(dst);
|
907
|
-
hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
908
|
-
break;
|
909
|
-
}
|
910
|
-
default:
|
911
|
-
GGML_ABORT("GGML tensor type not supported!\n");
|
912
|
-
}
|
913
|
-
}
|
914
|
-
|
915
|
-
inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
916
|
-
#if defined (GGML_SYCL_F16)
|
917
|
-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
918
|
-
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
919
|
-
#else
|
920
|
-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
921
|
-
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
922
|
-
#endif
|
923
|
-
GGML_ASSERT(dst->src[0]->type == dst->type);
|
924
|
-
dpct::queue_ptr main_stream = ctx.stream();
|
925
|
-
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
926
|
-
switch (dst->type) {
|
927
|
-
#if defined (GGML_SYCL_F16)
|
928
|
-
case GGML_TYPE_F16:
|
929
|
-
{
|
930
|
-
auto data_pts = cast_data<sycl::half>(dst);
|
931
|
-
hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
932
|
-
break;
|
933
|
-
}
|
934
|
-
#endif
|
935
|
-
case GGML_TYPE_F32:
|
936
|
-
{
|
937
|
-
auto data_pts = cast_data<float>(dst);
|
938
|
-
hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
939
|
-
break;
|
940
|
-
}
|
941
|
-
default:
|
942
|
-
GGML_ABORT("GGML tensor type not supported!\n");
|
943
|
-
}
|
944
|
-
}
|
945
|
-
|
946
|
-
inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
947
|
-
#if defined (GGML_SYCL_F16)
|
948
|
-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
949
|
-
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
950
|
-
#else
|
951
|
-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
952
|
-
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
953
|
-
#endif
|
954
|
-
GGML_ASSERT(dst->src[0]->type == dst->type);
|
955
|
-
dpct::queue_ptr main_stream = ctx.stream();
|
956
|
-
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
957
|
-
switch (dst->type) {
|
958
|
-
#if defined (GGML_SYCL_F16)
|
959
|
-
case GGML_TYPE_F16:
|
960
|
-
{
|
961
|
-
auto data_pts = cast_data<sycl::half>(dst);
|
962
|
-
exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
963
|
-
break;
|
964
|
-
}
|
965
|
-
#endif
|
966
|
-
case GGML_TYPE_F32:
|
967
|
-
{
|
968
|
-
auto data_pts = cast_data<float>(dst);
|
969
|
-
exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
970
|
-
break;
|
971
|
-
}
|
972
|
-
default:
|
973
|
-
GGML_ABORT("GGML tensor type not supported!\n");
|
974
|
-
}
|
975
|
-
}
|
976
|
-
|
977
|
-
inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
978
|
-
#if defined (GGML_SYCL_F16)
|
979
|
-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
980
|
-
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
981
|
-
#else
|
982
|
-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
983
|
-
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
984
|
-
#endif
|
985
|
-
GGML_ASSERT(dst->src[0]->type == dst->type);
|
986
|
-
dpct::queue_ptr main_stream = ctx.stream();
|
987
|
-
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
988
|
-
switch (dst->type) {
|
989
|
-
#if defined (GGML_SYCL_F16)
|
990
|
-
case GGML_TYPE_F16:
|
991
|
-
{
|
992
|
-
auto data_pts = cast_data<sycl::half>(dst);
|
993
|
-
log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
994
|
-
break;
|
995
|
-
}
|
996
|
-
#endif
|
997
|
-
case GGML_TYPE_F32:
|
998
|
-
{
|
999
|
-
auto data_pts = cast_data<float>(dst);
|
1000
|
-
log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
1001
|
-
break;
|
1002
|
-
}
|
1003
|
-
default:
|
1004
|
-
GGML_ABORT("GGML tensor type not supported!\n");
|
1005
|
-
}
|
1006
|
-
}
|
1007
|
-
|
1008
|
-
inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
1009
|
-
#if defined (GGML_SYCL_F16)
|
1010
|
-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
1011
|
-
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
1012
|
-
#else
|
1013
|
-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
1014
|
-
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
1015
|
-
#endif
|
1016
|
-
GGML_ASSERT(dst->src[0]->type == dst->type);
|
1017
|
-
dpct::queue_ptr main_stream = ctx.stream();
|
1018
|
-
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
1019
|
-
switch (dst->type) {
|
1020
|
-
#if defined (GGML_SYCL_F16)
|
1021
|
-
case GGML_TYPE_F16:
|
1022
|
-
{
|
1023
|
-
auto data_pts = cast_data<sycl::half>(dst);
|
1024
|
-
sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
1025
|
-
break;
|
1026
|
-
}
|
1027
|
-
#endif
|
1028
|
-
case GGML_TYPE_F32:
|
1029
|
-
{
|
1030
|
-
auto data_pts = cast_data<float>(dst);
|
1031
|
-
sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
1032
|
-
break;
|
1033
|
-
}
|
1034
|
-
default:
|
1035
|
-
GGML_ABORT("GGML tensor type not supported!\n");
|
1036
|
-
}
|
651
|
+
static inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
652
|
+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
653
|
+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
654
|
+
const int num_blocks = ceil_div(k_elements, 256);
|
655
|
+
stream->parallel_for(
|
656
|
+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
657
|
+
sycl::range<1>(256)),
|
658
|
+
[=](sycl::nd_item<1> item_ct1) {
|
659
|
+
unary_op_abs_kernel(src, dst_ptr, k_elements, item_ct1);
|
660
|
+
});
|
661
|
+
});
|
1037
662
|
}
|
1038
663
|
|
1039
|
-
inline void
|
1040
|
-
|
1041
|
-
|
1042
|
-
|
1043
|
-
|
1044
|
-
|
1045
|
-
|
1046
|
-
|
1047
|
-
|
1048
|
-
|
1049
|
-
|
1050
|
-
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
1051
|
-
switch (dst->type) {
|
1052
|
-
#if defined (GGML_SYCL_F16)
|
1053
|
-
case GGML_TYPE_F16:
|
1054
|
-
{
|
1055
|
-
auto data_pts = cast_data<sycl::half>(dst);
|
1056
|
-
sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
1057
|
-
break;
|
1058
|
-
}
|
1059
|
-
#endif
|
1060
|
-
case GGML_TYPE_F32:
|
1061
|
-
{
|
1062
|
-
auto data_pts = cast_data<float>(dst);
|
1063
|
-
sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
1064
|
-
break;
|
1065
|
-
}
|
1066
|
-
default:
|
1067
|
-
GGML_ABORT("GGML tensor type not supported!\n");
|
1068
|
-
}
|
664
|
+
static inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
665
|
+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
666
|
+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
667
|
+
const int num_blocks = ceil_div(k_elements, 256);
|
668
|
+
stream->parallel_for(
|
669
|
+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
670
|
+
sycl::range<1>(256)),
|
671
|
+
[=](sycl::nd_item<1> item_ct1) {
|
672
|
+
unary_op_elu_kernel(src, dst_ptr, k_elements, item_ct1);
|
673
|
+
});
|
674
|
+
});
|
1069
675
|
}
|
1070
676
|
|
1071
|
-
inline void
|
1072
|
-
|
1073
|
-
|
1074
|
-
|
1075
|
-
|
1076
|
-
|
1077
|
-
|
1078
|
-
|
1079
|
-
|
1080
|
-
|
1081
|
-
|
1082
|
-
switch (dst->type) {
|
1083
|
-
#if defined (GGML_SYCL_F16)
|
1084
|
-
case GGML_TYPE_F16:
|
1085
|
-
{
|
1086
|
-
auto data_pts = cast_data<sycl::half>(dst);
|
1087
|
-
sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
1088
|
-
break;
|
1089
|
-
}
|
1090
|
-
#endif
|
1091
|
-
case GGML_TYPE_F32:
|
1092
|
-
{
|
1093
|
-
auto data_pts = cast_data<float>(dst);
|
1094
|
-
sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
1095
|
-
break;
|
1096
|
-
}
|
1097
|
-
default:
|
1098
|
-
GGML_ABORT("GGML tensor type not supported!\n");
|
1099
|
-
}
|
677
|
+
static inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
678
|
+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
679
|
+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
680
|
+
const int num_blocks = ceil_div(k_elements, SYCL_SILU_BLOCK_SIZE);
|
681
|
+
stream->parallel_for(
|
682
|
+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SILU_BLOCK_SIZE),
|
683
|
+
sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
|
684
|
+
[=](sycl::nd_item<1> item_ct1) {
|
685
|
+
unary_op_silu_kernel(src, dst_ptr, k_elements, item_ct1);
|
686
|
+
});
|
687
|
+
});
|
1100
688
|
}
|
1101
689
|
|
1102
|
-
inline void
|
1103
|
-
|
1104
|
-
|
1105
|
-
|
1106
|
-
|
1107
|
-
|
1108
|
-
|
1109
|
-
|
1110
|
-
|
1111
|
-
|
1112
|
-
|
1113
|
-
switch (dst->type) {
|
1114
|
-
#if defined (GGML_SYCL_F16)
|
1115
|
-
case GGML_TYPE_F16:
|
1116
|
-
{
|
1117
|
-
auto data_pts = cast_data<sycl::half>(dst);
|
1118
|
-
cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
1119
|
-
break;
|
1120
|
-
}
|
1121
|
-
#endif
|
1122
|
-
case GGML_TYPE_F32:
|
1123
|
-
{
|
1124
|
-
auto data_pts = cast_data<float>(dst);
|
1125
|
-
cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
1126
|
-
break;
|
1127
|
-
}
|
1128
|
-
default:
|
1129
|
-
GGML_ABORT("GGML tensor type not supported!\n");
|
1130
|
-
}
|
690
|
+
static inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
691
|
+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
692
|
+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
693
|
+
const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
|
694
|
+
stream->parallel_for(
|
695
|
+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
|
696
|
+
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
697
|
+
[=](sycl::nd_item<1> item_ct1) {
|
698
|
+
unary_op_gelu_kernel(src, dst_ptr, k_elements, item_ct1);
|
699
|
+
});
|
700
|
+
});
|
1131
701
|
}
|
1132
702
|
|
1133
|
-
inline void
|
1134
|
-
|
1135
|
-
|
1136
|
-
|
1137
|
-
|
1138
|
-
|
1139
|
-
|
1140
|
-
|
1141
|
-
|
1142
|
-
|
1143
|
-
|
1144
|
-
switch (dst->type) {
|
1145
|
-
#if defined (GGML_SYCL_F16)
|
1146
|
-
case GGML_TYPE_F16:
|
1147
|
-
{
|
1148
|
-
auto data_pts = cast_data<sycl::half>(dst);
|
1149
|
-
step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
1150
|
-
break;
|
1151
|
-
}
|
1152
|
-
#endif
|
1153
|
-
case GGML_TYPE_F32:
|
1154
|
-
{
|
1155
|
-
auto data_pts = cast_data<float>(dst);
|
1156
|
-
step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
1157
|
-
break;
|
1158
|
-
}
|
1159
|
-
default:
|
1160
|
-
GGML_ABORT("GGML tensor type not supported!\n");
|
1161
|
-
}
|
703
|
+
static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
704
|
+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
705
|
+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
706
|
+
const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
|
707
|
+
stream->parallel_for(
|
708
|
+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
|
709
|
+
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
710
|
+
[=](sycl::nd_item<1> item_ct1) {
|
711
|
+
unary_op_gelu_quick_kernel(src, dst_ptr, k_elements, item_ct1);
|
712
|
+
});
|
713
|
+
});
|
1162
714
|
}
|
1163
715
|
|
1164
|
-
inline void
|
1165
|
-
|
1166
|
-
|
1167
|
-
|
1168
|
-
|
1169
|
-
|
1170
|
-
|
1171
|
-
|
1172
|
-
|
1173
|
-
|
1174
|
-
|
1175
|
-
switch (dst->type) {
|
1176
|
-
#if defined (GGML_SYCL_F16)
|
1177
|
-
case GGML_TYPE_F16:
|
1178
|
-
{
|
1179
|
-
auto data_pts = cast_data<sycl::half>(dst);
|
1180
|
-
neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
1181
|
-
break;
|
1182
|
-
}
|
1183
|
-
#endif
|
1184
|
-
case GGML_TYPE_F32:
|
1185
|
-
{
|
1186
|
-
auto data_pts = cast_data<float>(dst);
|
1187
|
-
neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
1188
|
-
break;
|
1189
|
-
}
|
1190
|
-
default:
|
1191
|
-
GGML_ABORT("GGML tensor type not supported!\n");
|
1192
|
-
}
|
716
|
+
static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
717
|
+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
718
|
+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
719
|
+
const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
|
720
|
+
stream->parallel_for(
|
721
|
+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
|
722
|
+
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
723
|
+
[=](sycl::nd_item<1> item_ct1) {
|
724
|
+
unary_op_gelu_erf_kernel(src, dst_ptr, k_elements, item_ct1);
|
725
|
+
});
|
726
|
+
});
|
1193
727
|
}
|
1194
728
|
|
1195
|
-
inline void
|
1196
|
-
|
1197
|
-
|
1198
|
-
|
1199
|
-
|
1200
|
-
|
1201
|
-
|
1202
|
-
|
729
|
+
static inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
730
|
+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
731
|
+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
732
|
+
const int num_blocks = ceil_div(k_elements, SYCL_TANH_BLOCK_SIZE);
|
733
|
+
stream->parallel_for(
|
734
|
+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_TANH_BLOCK_SIZE),
|
735
|
+
sycl::range<1>(SYCL_TANH_BLOCK_SIZE)),
|
736
|
+
[=](sycl::nd_item<1> item_ct1) {
|
737
|
+
unary_op_tanh_kernel(src, dst_ptr, k_elements, item_ct1);
|
738
|
+
});
|
739
|
+
});
|
740
|
+
}
|
1203
741
|
|
1204
|
-
|
1205
|
-
|
1206
|
-
|
1207
|
-
|
1208
|
-
|
1209
|
-
|
1210
|
-
|
1211
|
-
|
1212
|
-
|
1213
|
-
|
1214
|
-
|
1215
|
-
break;
|
1216
|
-
}
|
1217
|
-
#endif
|
1218
|
-
case GGML_TYPE_F32:
|
1219
|
-
{
|
1220
|
-
auto data_pts = cast_data<float>(dst);
|
1221
|
-
leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream);
|
1222
|
-
break;
|
1223
|
-
}
|
1224
|
-
default:
|
1225
|
-
GGML_ABORT("GGML tensor type not supported!\n");
|
1226
|
-
}
|
742
|
+
static inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
743
|
+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
744
|
+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
745
|
+
const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
|
746
|
+
stream->parallel_for(
|
747
|
+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
|
748
|
+
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
|
749
|
+
[=](sycl::nd_item<1> item_ct1) {
|
750
|
+
unary_op_relu_kernel(src, dst_ptr, k_elements, item_ct1);
|
751
|
+
});
|
752
|
+
});
|
1227
753
|
}
|
1228
754
|
|
1229
|
-
inline void
|
1230
|
-
|
1231
|
-
|
1232
|
-
|
1233
|
-
|
1234
|
-
|
1235
|
-
|
1236
|
-
|
1237
|
-
|
1238
|
-
|
1239
|
-
|
1240
|
-
switch (dst->type) {
|
1241
|
-
#if defined (GGML_SYCL_F16)
|
1242
|
-
case GGML_TYPE_F16:
|
1243
|
-
{
|
1244
|
-
auto data_pts = cast_data<sycl::half>(dst);
|
1245
|
-
sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
1246
|
-
break;
|
1247
|
-
}
|
1248
|
-
#endif
|
1249
|
-
case GGML_TYPE_F32:
|
1250
|
-
{
|
1251
|
-
auto data_pts = cast_data<float>(dst);
|
1252
|
-
sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
|
1253
|
-
break;
|
1254
|
-
}
|
1255
|
-
default:
|
1256
|
-
GGML_ABORT("GGML tensor type not supported!\n");
|
1257
|
-
}
|
755
|
+
static inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
756
|
+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
757
|
+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
758
|
+
const int num_blocks = ceil_div(k_elements, SYCL_HARDSIGMOID_BLOCK_SIZE);
|
759
|
+
stream->parallel_for(
|
760
|
+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE),
|
761
|
+
sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE)),
|
762
|
+
[=](sycl::nd_item<1> item_ct1) {
|
763
|
+
unary_op_hardsigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
|
764
|
+
});
|
765
|
+
});
|
1258
766
|
}
|
1259
767
|
|
1260
|
-
inline void
|
1261
|
-
|
1262
|
-
|
1263
|
-
|
1264
|
-
|
1265
|
-
|
1266
|
-
|
1267
|
-
|
1268
|
-
|
768
|
+
static inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
769
|
+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
770
|
+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
771
|
+
const int num_blocks = ceil_div(k_elements, SYCL_HARDSWISH_BLOCK_SIZE);
|
772
|
+
stream->parallel_for(
|
773
|
+
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_hardswish_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}

-
-
+static inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE);
+            stream->parallel_for(
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_exp_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}

-
-
-
-
-
-
-
-
-
-
-
-                             main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2],
-                             dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
-                             main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
+static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE); // Using EXP block size
+            stream->parallel_for(
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_log_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
 }

-inline void
-
-
-
-
-
-
-
-
-
-
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0],
-                         dst->ne[1], dst->ne[2], main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0],
-                         dst->ne[1], dst->ne[2], main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
+static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE);
+            stream->parallel_for(
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_neg_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
 }

-inline void
-
-
-
-
+static inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE); // Using NEG block size
+            stream->parallel_for(
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_step_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}

-
-
-
-
-
-
-
-
-
-
+static inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_SIGMOID_BLOCK_SIZE);
+            stream->parallel_for(
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_sigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_SQRT_BLOCK_SIZE);
+            stream->parallel_for(
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_sqrt_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}
+
+static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE);
+            stream->parallel_for(
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_sin_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}
+
+static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE); // Using SIN block size
+            stream->parallel_for(
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_cos_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
 }

-inline void
+static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float slope) {
+            const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
+            stream->parallel_for(
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_leaky_relu_kernel(src, dst_ptr, k_elements, slope, item_ct1);
+                });
+        }, negative_slope);
+}
+
+static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_SQR_BLOCK_SIZE);
+            stream->parallel_for(
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_SQR_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_sqr_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}
+
+static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_upscale(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int nb00, int nb01, int nb02, int nb03,
+           int ne10, int ne11, int ne12, int ne13, float sf0, float sf1, float sf2, float sf3,
+           queue_ptr stream) {
+            ggml_sycl_detail::upscale_sycl(src, dst_ptr, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, stream);
+        });
+}
+
+static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_pad(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2,
+           queue_ptr stream) {
+            ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream);
+        });
+}

+static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    float min_val;
+    float max_val;
+    memcpy(&min_val, dst->op_params, sizeof(float));
+    memcpy(&max_val, (float *) dst->op_params + 1, sizeof(float));
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float min_arg, float max_arg) {
+            const int num_blocks = ceil_div(k_elements, SYCL_CLAMP_BLOCK_SIZE);
+            stream->parallel_for(
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1);
+                });
+        }, min_val, max_val);
+}
+
+static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -1386,7 +960,62 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
     // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
     int offset = dst->op_params[3] / 4; // offset in bytes

-    acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
+    ggml_sycl_detail::acc_f32_sycl(src0_dd, src1_dd, dst_dd, (int)ggml_nelements(dst), (int)dst->src[1]->ne[0], (int)dst->src[1]->ne[1], (int)dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
+}
+
+static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+            const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+            main_stream->parallel_for(
+                sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                    gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+                });
+        });
+}
+
+static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+            const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu
+            main_stream->parallel_for(
+                sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                    gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+                });
+        });
+}
+
+static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+            const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu
+            main_stream->parallel_for(
+                sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                    gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+                });
+        });
+}
+
+static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+            const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+            main_stream->parallel_for(
+                sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                    gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+                });
+        });
+}
+
+static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+            const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+            main_stream->parallel_for(
+                sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                    gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+                });
+        });
 }


@@ -1425,6 +1054,11 @@ void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     ggml_sycl_op_gelu_quick(ctx, dst);
 }

+void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_gelu_erf(ctx, dst);
+}
+
 void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
     ggml_sycl_op_tanh(ctx, dst);
@@ -1509,3 +1143,28 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
     ggml_sycl_op_elu(ctx, dst);
 }
+
+void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_geglu(ctx, dst);
+}
+
+void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_reglu(ctx, dst);
+}
+
+void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_swiglu(ctx, dst);
+}
+
+void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_geglu_erf(ctx, dst);
+}
+
+void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_geglu_quick(ctx, dst);
+}