whispercpp 1.3.3 → 1.3.4
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/ruby_whisper_params.c +55 -25
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/build-xcframework.sh +24 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/addon.cpp +19 -19
- data/ext/sources/examples/addon.node/index.js +7 -5
- data/ext/sources/examples/bench/bench.cpp +26 -16
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +4 -2
- data/ext/sources/examples/command/command.cpp +26 -24
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/lsp.cpp +19 -17
- data/ext/sources/examples/server/server.cpp +24 -13
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +4 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
- data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
- data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
- data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
- data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
- data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
- data/ext/sources/examples/talk-llama/llama-context.h +44 -29
- data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
- data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
- data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
- data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
- data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
- data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
- data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
- data/ext/sources/examples/talk-llama/llama-model.h +60 -9
- data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
- data/ext/sources/examples/talk-llama/llama.cpp +65 -10
- data/ext/sources/examples/talk-llama/llama.h +95 -177
- data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
- data/ext/sources/ggml/CMakeLists.txt +59 -31
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-backend.h +17 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -1
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml.h +221 -16
- data/ext/sources/ggml/src/CMakeLists.txt +17 -2
- data/ext/sources/ggml/src/ggml-alloc.c +265 -141
- data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
- data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
- data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
- data/ext/sources/ggml/src/ggml-common.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
- data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
- data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
- data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
- data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
- data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
- data/ext/sources/ggml/src/ggml-impl.h +119 -9
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +111 -16
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml.c +478 -98
- data/ext/sources/ggml/src/gguf.cpp +8 -1
- data/ext/sources/src/whisper.cpp +23 -46
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +3 -3
- data/ext/sources/tests/test-vad.cpp +2 -2
- data/lib/whisper/model/uri.rb +1 -1
- data/sig/whisper.rbs +7 -0
- data/test/test_params.rb +8 -0
- data/test/test_whisper.rb +1 -1
- data/whispercpp.gemspec +1 -1
- metadata +164 -157
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp

@@ -41,8 +41,10 @@
 #include "ggml-sycl/element_wise.hpp"
 #include "ggml-sycl/presets.hpp"
 #include "ggml-sycl/gemm.hpp"
+#include "ggml-sycl/set_rows.hpp"
 #include "ggml-sycl/sycl_hw.hpp"
 #include "ggml-sycl/getrows.hpp"
+#include "ggml-sycl/quantize.hpp"
 #include "ggml.h"
 
 static bool g_sycl_loaded = false;
@@ -83,7 +85,7 @@ static ggml_sycl_device_info ggml_sycl_init() {
 
         info.devices[i].cc =
             100 * prop.get_major_version() + 10 * prop.get_minor_version();
-        info.devices[i].opt_feature.reorder =
+        info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
         info.max_work_group_sizes[i] = prop.get_max_work_group_size();
     }
 
@@ -1372,120 +1374,6 @@ typedef void (*ggml_sycl_op_mul_mat_t)(
 
 
 
-template<int QUANT_BLOCK_TILE>
-static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
-                          const sycl::nd_item<3> &item_ct1) {
-    const int ix = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2)) * QUANT_BLOCK_TILE;
-
-    if (ix >= kx_padded) {
-        return;
-    }
-
-    const int iy = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                   item_ct1.get_local_id(1);
-
-    const int i_padded = iy*kx_padded + ix;
-
-    block_q8_1 * y = (block_q8_1 *) vy;
-
-    const int ib = i_padded / QK8_1; // block index
-    const int iqs = i_padded % QK8_1; // quant index
-    typedef sycl::vec<float, QUANT_BLOCK_TILE> TC;
-    typedef sycl::vec<int8_t, QUANT_BLOCK_TILE> TQ;
-    TC zeros;
-    TQ qzeros;
-#pragma unroll
-    for (int i = 0; i < QUANT_BLOCK_TILE; i++)
-    {
-        zeros[i] = 0.f;
-        qzeros[i] = 0;
-    }
-    const TC xi = ix < kx ? *(const TC *)&x[iy * kx + ix] : zeros;
-    float sum = xi[0];
-    float amax = sycl::fabs(xi[0]);
-#pragma unroll
-    for (int i = 1; i < QUANT_BLOCK_TILE; i++)
-    {
-        sum += xi[i];
-        amax = sycl::fmax(sycl::fabs(xi[i]), amax);
-    }
-    sum = warp_reduce_sum(sum, item_ct1);
-    amax = warp_reduce_max(amax, item_ct1);
-
-    const float d = amax / 127;
-    TQ q = qzeros;
-    if (amax != 0.0f)
-    {
-#pragma unroll
-        for (int i = 0; i < QUANT_BLOCK_TILE; i++) {
-            q[i] = sycl::round(xi[i] / d);
-        }
-    }
-
-    *(TQ *)&y[ib].qs[iqs] = q;
-
-    if (iqs > 0) {
-        return;
-    }
-
-    reinterpret_cast<sycl::half &>(y[ib].ds.x()) = d;
-    reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
-}
-
-template <int ElementsPerWI>
-static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor,
-                                                      const int kx, const int kx_padded, const sycl::nd_item<1> & it) {
-    /*
-        Quantizes and reorders the resultant q8 tensor in a per row fashion
-        Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
-    */
-
-    auto subgroup_id = it.get_group(0);
-    auto wi_id = it.get_local_id(0);
-
-    const int num_blocks_per_row = kx / QK8_1;
-    auto row = subgroup_id / num_blocks_per_row;
-    auto col = subgroup_id % num_blocks_per_row;
-
-    auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1);
-    auto col_offset = QK8_1 * col + wi_id * ElementsPerWI;
-
-    auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
-    auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2));
-
-    sycl::vec<float, ElementsPerWI> wi_f32_vals;
-    sycl::vec<int8_t, ElementsPerWI> quantized_values;
-
-    auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
-    wi_f32_vals = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset);
-
-    float sum = 0.0f;
-    float amax = 0.0f;
-
-#pragma unroll(ElementsPerWI)
-    for (int i = 0; i < ElementsPerWI; i++) {
-        sum += wi_f32_vals[i];
-        amax = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i]));
-        quantized_values[i] = 0;
-    }
-    sum = sycl::reduce_over_group(it.get_group(), sum, sycl::plus<float>());
-    amax = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum<float>());
-    float d = amax == 0 ? 1 : amax / 127;
-
-#pragma unroll(ElementsPerWI)
-    for (int i = 0; i < ElementsPerWI; i++) {
-        quantized_values[i] = sycl::round(wi_f32_vals[i] / d);
-    }
-
-    d = amax == 0 ? 0 : d;
-
-    *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values;
-    if (wi_id == 0) {
-        *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum));
-    }
-}
-
 static void mul_mat_p021_f16_f32(
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
@@ -1545,7 +1433,7 @@ static void mul_mat_p021_f16_f32(
 
 static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor,
+    const int row_stride_x, const int channel_stride_x,const int channel_stride_y, const int channel_x_divisor,
     const sycl::nd_item<3> &item_ct1) {
 
     const sycl::half *x = (const sycl::half *)vx;
@@ -1556,7 +1444,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
                           item_ct1.get_local_id(0);
     const int channel_x = channel / channel_x_divisor;
 
-    const int nrows_y = ncols_x;
     const int nrows_dst = nrows_x;
     const int row_dst = row_x;
 
@@ -1575,7 +1462,7 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     const int row_y = col_x;
 
     const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
-    const int iy = channel*
+    const int iy = channel * channel_stride_y + row_y;
 
     const float xi =
         sycl::vec<sycl::half, 1>(x[ix])
@@ -1695,7 +1582,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
         dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }
 
-static void scale_f32(const float * x, float * dst, const float scale, const int k,
+static void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k,
                       const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                   item_ct1.get_local_id(2);
@@ -1704,7 +1591,7 @@ static void scale_f32(const float * x, float * dst, const float scale, const int
         return;
     }
 
-    dst[i] = scale * x[i];
+    dst[i] = scale * x[i] + bias;
 }
 
 
@@ -1770,32 +1657,6 @@ static void pool2d_nchw_kernel(
         o_ptr[cur_oh * ow + cur_ow] = res;
 }
 
-static void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded,
-                                   bool reorder_q8_tensor, queue_ptr stream) {
-    if (reorder_q8_tensor) {
-        auto local_range = std::size_t(WARP_SIZE);
-        auto num_quant_blocks = ky * (kx / QK8_1);
-        auto global_range = num_quant_blocks * local_range;
-        stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }),
-                             [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                 quantize_and_reorder_q8_1<QK8_1 / WARP_SIZE>(x, vy, kx, kx_padded, it);
-                             });
-    } else {
-        const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
-        const sycl::range<3> num_blocks(1, ky, block_num_x);
-        int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
-        static_assert(QK8_1 % WARP_SIZE == 0);
-        const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
-        {
-            dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
-
-            stream->parallel_for(sycl::nd_range<3>(num_blocks * block_size, block_size),
-                                 [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                     quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
-                                 });
-        }
-    }
-}
 
 static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
                                            float *dst, const int ncols_x,
@@ -1822,7 +1683,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
 static void ggml_mul_mat_vec_nc_f16_f32_sycl(
     const void *vx, const float *y, float *dst, const int ncols_x,
     const int nrows_x, const int row_stride_x, const int nchannels_x,
-    const int nchannels_y, const int channel_stride_x, queue_ptr stream) {
+    const int nchannels_y, const int channel_stride_x, const int channel_stride_y, queue_ptr stream) {
 
     const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
     const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -1834,7 +1695,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
-                                       row_stride_x, channel_stride_x,
+                                       row_stride_x, channel_stride_x, channel_stride_y,
                                        nchannels_y / nchannels_x, item_ct1);
             });
     }
@@ -1842,7 +1703,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
 
 
 
-static void scale_f32_sycl(const float *x, float *dst, const float scale,
+static void scale_f32_sycl(const float *x, float *dst, const float scale, const float bias,
                            const int k, queue_ptr stream) {
     const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;
     stream->parallel_for(
@@ -1850,7 +1711,7 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale,
             sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),
             sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),
         [=](sycl::nd_item<3> item_ct1) {
-            scale_f32(x, dst, scale, k, item_ct1);
+            scale_f32(x, dst, scale, bias, k, item_ct1);
         });
 }
 
@@ -1885,12 +1746,13 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
     const size_t shared_mem = ncols_pad * sizeof(int);
 
     if (order == GGML_SORT_ORDER_ASC) {
-
+        stream->submit([&](sycl::handler &cgh) {
             sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
                 sycl::range<1>(shared_mem), cgh);
 
-
-
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
                     k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
                         x, dst, ncols, ncols_pad, item_ct1,
                         dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
@@ -1898,12 +1760,13 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
                 });
         });
     } else if (order == GGML_SORT_ORDER_DESC) {
-
+        stream->submit([&](sycl::handler &cgh) {
             sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
                 sycl::range<1>(shared_mem), cgh);
 
-
-
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
                     k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
                         x, dst, ncols, ncols_pad, item_ct1,
                         dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
@@ -1921,47 +1784,50 @@ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
     const sycl::range<3> block_nums(1, nrows, 1);
     const size_t shared_mem = 256 * sizeof(float);
 
-
+    stream->submit([&](sycl::handler &cgh) {
         sycl::local_accessor<float, 1> shared_data(
             sycl::range<1>(shared_mem/sizeof(float)), cgh);
         sycl::local_accessor<int, 1> shared_indices(
             sycl::range<1>(shared_mem/sizeof(float)), cgh);
 
-
-
-
-
-
-
-
-
-
-
-
-
+        cgh.parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                const int tid = item_ct1.get_local_id(2);
+                const int row = item_ct1.get_global_id(1);
+
+                float max_val = -INFINITY;
+                int max_idx = -1;
+
+                for (int col = tid; col < ncols; col += 256) {
+                    float val = x[row * ncols + col];
+                    if (val > max_val) {
+                        max_val = val;
+                        max_idx = col;
+                    }
                 }
-                }
 
-
-
+                shared_data[tid] = max_val;
+                shared_indices[tid] = max_idx;
+                item_ct1.barrier(sycl::access::fence_space::local_space);
 
-
-
-
-
-
-
-
+                for (int stride = 256/2; stride > 0; stride >>= 1) {
+                    if (tid < stride) {
+                        float val1 = shared_data[tid];
+                        float val2 = shared_data[tid + stride];
+                        if (val2 > val1) {
+                            shared_data[tid] = val2;
+                            shared_indices[tid] = shared_indices[tid + stride];
+                        }
                     }
+                    item_ct1.barrier(sycl::access::fence_space::local_space);
                 }
-            item_ct1.barrier(sycl::access::fence_space::local_space);
-        }
 
-
-
-
-
+
+                if (tid == 0) {
+                    dst[row] = shared_indices[0];
+                }
+            });
     });
 }
 static void diag_mask_inf_f32_sycl(const float *x, float *dst,
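The rewritten argmax launch above folds everything into one submit/parallel_for: each work-item scans a strided slice of its row (col = tid, tid + 256, ...), the per-item best value and index go to local memory, and a tree reduction halves the active range until work-item 0 writes the row's argmax. A small host-side reference with the same result can be useful when validating such a kernel; this is an illustrative sketch, not part of the diff:

    // Reference argmax over each row of a row-major [nrows x ncols] float matrix.
    #include <cmath>
    #include <cstdint>
    #include <vector>

    static std::vector<int> argmax_rows_ref(const float * x, int64_t ncols, int64_t nrows) {
        std::vector<int> out(nrows, -1);
        for (int64_t r = 0; r < nrows; ++r) {
            float best = -INFINITY;               // same sentinel as the kernel
            for (int64_t c = 0; c < ncols; ++c) {
                const float v = x[r * ncols + c];
                if (v > best) {                   // strict '>' keeps the earliest maximum here;
                    best = v;                     // the parallel reduction may pick a different
                    out[r] = (int) c;             // index when values tie exactly
                }
            }
        }
        return out;
    }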
@@ -2123,8 +1989,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
 #if GGML_SYCL_DNNL
         if (!g_ggml_sycl_disable_dnn) {
-
-
+            DnnlGemmWrapper::row_gemm(ctx,row_diff, src1_ncols , ne10, src0_ptr,
+                DnnlGemmWrapper::to_dt<sycl::half>(), src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
                 dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
         }
         else
@@ -2170,8 +2036,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
 #if GGML_SYCL_DNNL
         if (!g_ggml_sycl_disable_dnn) {
-            DnnlGemmWrapper::row_gemm(ctx,
-                DnnlGemmWrapper::to_dt<float>(),
+            DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i,
+                DnnlGemmWrapper::to_dt<float>(), src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
                 dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
         }
         else
@@ -2319,9 +2185,11 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * ds
     float * dst_dd = static_cast<float *>(dst->data);
 
     float scale;
-
+    float bias;
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&bias, (float *) dst->op_params + 1, sizeof(float));
 
-    scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
+    scale_f32_sycl(src0_dd, dst_dd, scale, bias, ggml_nelements(dst->src[0]), main_stream);
     /*
     DPCT1010:87: SYCL uses exceptions to report errors and does not use the
     error codes. The call was replaced with 0. You need to rewrite this code.
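As the op_scale hunk shows, GGML_OP_SCALE now carries two packed floats in the tensor's op_params: scale at slot 0 and bias at slot 1, read back with memcpy before the fused kernel is launched. A hedged sketch of that unpacking step on its own (op_params is ggml's small int32 parameter block; the struct and helper below are illustrative):

    // Illustrative: pull {scale, bias} back out of a ggml op_params block.
    // Floats are stored bit-for-bit in the int32 words, so memcpy is the
    // portable way to read them, exactly as the diff does.
    #include <cstdint>
    #include <cstring>

    struct scale_params { float scale; float bias; };

    static scale_params read_scale_params(const int32_t * op_params) {
        scale_params p{};
        std::memcpy(&p.scale, (const float *) op_params + 0, sizeof(float));
        std::memcpy(&p.bias,  (const float *) op_params + 1, sizeof(float));
        return p;
    }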
@@ -2370,10 +2238,10 @@ static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
     peer_access_enabled = enable_peer_access;
 }
 
+template <template <int> typename quantize_f>
 static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                  const ggml_tensor *src1, ggml_tensor *dst,
-                                 ggml_sycl_op_mul_mat_t op,
-                                 const bool convert_src1_to_q8_1) try {
+                                 ggml_sycl_op_mul_mat_t op) try {
 
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
 
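The runtime flag convert_src1_to_q8_1 is gone: ggml_sycl_op_mul_mat is now parameterized by a template-template argument, so the quantization strategy is fixed at compile time and the no_quantize_q8_1 instantiation can be detected with std::is_same_v and pruned with if constexpr. A minimal self-contained sketch of that dispatch pattern (functor and function names below are stand-ins, not the ggml-sycl ones):

    // Illustrative compile-time dispatch on a template-template parameter.
    #include <type_traits>

    template <int BLOCK> struct no_quantize { /* marker type: skip the quantize pass */ };
    template <int BLOCK> struct quantize_q8 {
        static void run(const float * /*src*/, char * /*dst*/, int /*n*/) { /* ... */ }
    };

    template <template <int> typename quantize_f>
    static void run_mul_mat(const float * src1_f32, char * src1_q8, int n) {
        constexpr bool quantize_enabled = !std::is_same_v<quantize_f<32>, no_quantize<32>>;
        if constexpr (quantize_enabled) {
            quantize_f<32>::run(src1_f32, src1_q8, n);  // compiled only for real quantizers
        }
        // ... launch the matrix multiply on src1_f32 or src1_q8 accordingly ...
    }

    // Call sites select the path the same way the diff does:
    //   run_mul_mat<quantize_q8>(...);   // quantizing path
    //   run_mul_mat<no_quantize>(...);   // float path, quantize step compiled out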
@@ -2468,6 +2336,8 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
         }
     }
 
+    constexpr bool quantize_enabled = !std::is_same_v<quantize_f<QK8_1 / WARP_SIZE>,
+                                                      no_quantize_q8_1<QK8_1 / WARP_SIZE>>;
     for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
         if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
             continue;
@@ -2493,20 +2363,19 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
             dev[i].src1_ddf = dev[i].src1_ddf_alloc.alloc(ctx.pool(i), ggml_nelements(src1));
         }
 
-        if (
+        if constexpr(quantize_enabled) {
             dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
 
             if (src1_on_device && src1_is_contiguous) {
-                bool reorder_q8_tensor = src0->extra && ((ggml_tensor_extra_gpu *)src0->extra)->optimized_feature.reorder;
                 scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
                                                      /*num_src=*/2, " : converting src1 to Q8_1");
-
-
-
-
-
-
-
+                try {
+                    quantize_row_q8_1_sycl<quantize_f>(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
+                } catch (sycl::exception const &exc) {
+                    std::cerr << "Quantize_row_q8_1_sycl error" << exc.what() << "Exception caught at file:" << __FILE__
+                              << ", line:" << __LINE__ << std::endl;
+                    std::exit(1);
+                }
             }
         }
 
@@ -2522,11 +2391,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
     // here an event is recorded that signals that the main device has finished calculating the input data
     if (split && used_devices > 1) {
         ggml_sycl_set_device(ctx.device);
-        /*
-        DPCT1024:91: The original code returned the error code that was further
-        consumed by the program logic. This original code was replaced with 0.
-        You may need to rewrite the program logic consuming the error code.
-        */
         SYCL_CHECK(CHECK_TRY_ERROR(
             *src0_extra->events[ctx.device][0] =
                 ctx.stream()->ext_oneapi_submit_barrier()));
@@ -2550,11 +2414,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
 
         // wait for main GPU data if necessary
         if (split && (i != ctx.device || is != 0)) {
-            /*
-            DPCT1009:163: SYCL uses exceptions to report errors and does not
-            use the error codes. The original code was commented out and a
-            warning string was inserted. You need to rewrite this code.
-            */
             SYCL_CHECK(CHECK_TRY_ERROR(stream->ext_oneapi_submit_barrier(
                 {*src0_extra->events[ctx.device][0]})));
         }
@@ -2580,39 +2439,42 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
                 // copy src0, src1 to device if necessary
                 if (src1_is_contiguous) {
                     if (i != ctx.device) {
-                        if (
+                        if constexpr (quantize_enabled) {
                             char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
-
-
-
-
+                            SYCL_CHECK(
+                                CHECK_TRY_ERROR(stream
+                                                    ->memcpy(src1_ddq_i, src1_ddq_i_source,
+                                                             src1_ncols * src1_padded_col_size * q8_1_ts / q8_1_bs)
+                                                    .wait()));
                         } else {
-
                             float * src1_ddf_i_source = (float *) src1_extra->data_device[ctx.device];
-                            src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
+                            src1_ddf_i_source += (i0 * ne11 + src1_col_0) * ne10;
 
-                            SYCL_CHECK(
-                                src1_ddf_i, src1_ddf_i_source,
-
+                            SYCL_CHECK(
+                                CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream, src1_ddf_i, src1_ddf_i_source,
+                                                               src1_ncols * ne10 * sizeof(float))));
                         }
                     }
-                } else if (src1_on_device && !src1_is_contiguous) {
-                    SYCL_CHECK(ggml_sycl_cpy_tensor_2d(
-                        src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
                 } else {
-
-
+                    if (src1_on_device) {
+                        SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, src1_col_0,
+                                                           src1_col_0 + src1_ncols, stream));
+                    } else {
+                        GGML_ABORT("src1 is non-contiguous and not on device");
+                    }
 
-
-
-
-
-
-
-
-
-
+                    if constexpr (quantize_enabled) {
+                        scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
+                                                             /*num_src=*/2, " : converting src1 to Q8_1");
+                        try {
+                            quantize_row_q8_1_sycl<quantize_q8_1>(src1_ddf_i, src1_ddq_i, ne10, src1_ncols,
+                                                                  src1_padded_col_size, stream);
+                        } catch (const sycl::exception & exc) {
+                            std::cerr << "Quantize_row_q8_1_sycl error" << exc.what()
+                                      << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
+                            std::exit(1);
+                        }
+                    }
                 }
 
                 if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
@@ -2624,12 +2486,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
                 // do the computation
                 SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
                     dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));
-                /*
-                DPCT1010:93: SYCL uses exceptions to report errors and does not
-                use the error codes. The call was replaced with 0. You need to
-                rewrite this code.
-                */
-                SYCL_CHECK(0);
 
                 // copy dst to host or other device if necessary
                 if (!dst_on_device) {
@@ -2660,12 +2516,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
 
                 // add event for the main device to wait on until other device is done
                 if (split && (i != ctx.device || is != 0)) {
-                    /*
-                    DPCT1024:94: The original code returned the error code that
-                    was further consumed by the program logic. This original
-                    code was replaced with 0. You may need to rewrite the
-                    program logic consuming the error code.
-                    */
                     SYCL_CHECK(CHECK_TRY_ERROR(
                         *src0_extra->events[i][is] =
                             stream->ext_oneapi_submit_barrier()));
@@ -2764,6 +2614,8 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
     GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->ne[1] == 1);
+    GGML_ASSERT(src1->ne[3] == 1);
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
@@ -2773,6 +2625,7 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
     const int64_t nb02 = src0->nb[2];
 
     const int64_t ne12 = src1->ne[2];
+    const int64_t nb11 = src1->nb[1];
 
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
     queue_ptr main_stream = ctx.stream();
@@ -2783,8 +2636,9 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
 
     const int64_t row_stride_x = nb01 / sizeof(sycl::half);
     const int64_t channel_stride_x = nb02 / sizeof(sycl::half);
+    const int64_t channel_stride_y = nb11 / sizeof(float);
 
-    ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
+    ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x,channel_stride_y, main_stream);
 }
 catch (sycl::exception const &exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -2838,8 +2692,11 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     float * dst_ddf = static_cast<float *>(dst->data);
 
     const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
+    const size_t type_size_src0 = ggml_type_size(src0->type);
     const size_t type_size_src1 = ggml_type_size(src1->type);
-
+
+    bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
+    bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
 
     // SRC1 strides
     int64_t s11 = nb11 / type_size_src1;
@@ -2851,16 +2708,47 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     if (src1->type != GGML_TYPE_F16) {
         scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2,
                                              " : converting src1 to fp16");
-
-
+
+        // iterate tensor dims and find the slowest moving dim and stride
+        int last_dim=0;
+        int last_str=0;
+        size_t largest_str=0;
+        for(int i = 0; i< 4; i++){
+            // last stride is always the largest
+            if(src1->nb[i] == largest_str){
+                if(src1->ne[last_dim] == 1){
+                    last_str = i;
+                    last_dim = i;
+                }
+            }
+            if(src1->nb[i] > largest_str){
+                largest_str = src1->nb[i];
+                last_str = i;
+                last_dim = i;
+            }
+
+        }
+#if GGML_SYCL_DNNL
+        // oneDNN handles strided data and does not need overhead of get_to_fp16_nc_sycl
+        const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
+        src1_f16_alloc.alloc(ne_src1);
+        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
+        GGML_ASSERT(to_fp16_sycl != nullptr);
+        to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue);
+# else
         const int64_t ne_src1 = ggml_nelements(src1);
         src1_f16_alloc.alloc(ne_src1);
+        const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
+        GGML_ASSERT(to_fp16_nc_sycl != nullptr);
         to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
+#endif
 
         src1_f16 = src1_f16_alloc.get();
         s11 = ne10;
         s12 = ne11 * s11;
         s13 = ne12 * s12;
+
+        is_src1_cont_2 = true;
     }
 
     ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
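The loop added above sizes the fp16 staging buffer for a possibly strided src1 by finding its slowest-moving dimension: it scans the four nb[] strides, keeps the largest, and lets a size-1 dimension that shares that stride take over as the effective last dimension. The same scan in a compact standalone form, assuming ggml's ne[4]/nb[4] layout (plain arrays here rather than a ggml_tensor):

    // Illustrative: element count needed to cover a strided 4-D tensor view,
    // derived from its largest-stride dimension.
    #include <cstddef>
    #include <cstdint>

    static int64_t strided_nelements(const int64_t ne[4], const size_t nb[4], size_t type_size) {
        int    last_dim    = 0;
        size_t largest_str = 0;
        for (int i = 0; i < 4; i++) {
            // a later dim with the same stride only wins if the current one has size 1
            if (nb[i] == largest_str && ne[last_dim] == 1) {
                last_dim = i;
            }
            if (nb[i] > largest_str) {
                largest_str = nb[i];
                last_dim    = i;
            }
        }
        // stride of the slowest dim times its extent covers the whole view
        return (int64_t) (nb[last_dim] * ne[last_dim] / type_size);
    }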
@@ -2889,48 +2777,115 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
 
 #if GGML_SYCL_DNNL
     if (!g_ggml_sycl_disable_dnn) {
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        int64_t str_a0 = nb00 / type_size_src0;
+        int64_t str_a1 = nb01 / type_size_src0;
+        int64_t str_a2 = nb02 / type_size_src0;
+
+        int64_t str_b0 = nb10 / type_size_src1;
+        int64_t str_b1 = nb11 / type_size_src1;
+        int64_t str_b2 = nb12 / type_size_src1;
+
+        auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0,
+                                                     const sycl::half *src1, float *dst,
+                                                     int64_t a0, int64_t a1, int64_t batcha,
+                                                     int64_t /*b0*/, int64_t b1, int64_t batchb,
+                                                     int64_t sa0, int64_t sa1, int64_t sa2,
+                                                     int64_t sb0, int64_t sb1, int64_t sb2,
+                                                     int64_t sd2) {
+            bool supported_broadcast = batchb == batcha ? true
+                                       : batchb == 1 || batcha == 1 ? true
+                                       : false;
+            if (supported_broadcast) {
+                DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0,
+                    DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2, src1,
+                    DnnlGemmWrapper::to_dt<sycl::half>(), sb0, sb1, sb2, dst,
+                    DnnlGemmWrapper::to_dt<float>(), queue, batcha, batchb);
+            } else {
+                // iterate over batches from smaller set of matrices (matrix 0)
+                int64_t batches0 = batcha;
+                int64_t batches1 = batchb;
+
+                if (batches0 > batches1) {
+                    int64_t num_mul_mats = batches1;
+                    int64_t sub_batch = batches0 / num_mul_mats;
+                    // src0 is batched and bigger, shift and multiply with src1
+                    for (int64_t i0 = 0; i0 < num_mul_mats; i0++) {
+                        const sycl::half *src0_shifted = src0 + (sa2 * i0 * sub_batch);
+                        const sycl::half *src1_shifted = src1 + (sb2 * i0);
+                        float *dst_shifted = dst + (sd2 * i0 * sub_batch);
+                        DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
+                            DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
+                            src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
+                            sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
+                            queue, sub_batch, 1);
+                    }
+                } else {
+                    int64_t num_mul_mats = batches0;
+                    int64_t sub_batch = batches1 / num_mul_mats;
+                    // src1 is batched and bigger, shift and multiply with src0
+                    for (int64_t i1 = 0; i1 < num_mul_mats; i1++) {
+                        const sycl::half *src0_shifted = src0 + (sa2 * i1);
+                        const sycl::half *src1_shifted = src1 + (sb2 * i1 * sub_batch);
+                        float *dst_shifted = dst + (sd2 * i1 * sub_batch);
+                        DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
+                            DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
+                            src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
+                            sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
+                            queue, 1, sub_batch);
+                    }
+                }
             }
-        }
-
-
-
-
-
-
-
-
+        };
+
+        const bool cont_batches_dim2_a = nb02 * ne02 == nb03;
+        const bool cont_batches_dim2_b = nb12 * ne12 == nb13;
+        const bool cont_batches_dim3_a = ne02 == 1 && nb02 * ne01 == nb03;
+        const bool cont_batches_dim3_b = ne12 == 1 && nb12 * ne11 == nb13;
+        if (cont_batches_dim2_a && cont_batches_dim2_b) {
+            // A batch is considered contiguous if the dimension 2 is not strided
+            int64_t batches0 = ne02 * ne03;
+            int64_t batches1 = ne12 * ne13;
+            launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
+                ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1,
+                str_b2, nb2 / sizeof(float));
+        } else if (cont_batches_dim3_a && cont_batches_dim3_b) {
+            // This case is similar to the one above with the difference that only the batch in dimension 3 is used and the dimension 2 is of size 1.
+            int64_t batches0 = ne02 * ne03;
+            int64_t batches1 = ne12 * ne13;
+            int64_t str_a3 = nb03 / type_size_src0;
+            int64_t str_b3 = nb13 / type_size_src1;
+            launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
+                ne10, ne11, batches1, str_a0, str_a1, str_a3, str_b0, str_b1,
+                str_b3, nb2 / sizeof(float));
+        } else {
+            for (int64_t b_a = 0; b_a < ne03; b_a++) {
+                const sycl::half *src0_f16_shifted
+                    = src0_f16 + (nb03 * b_a / type_size_src0);
+                const sycl::half *src1_f16_shifted
+                    = src1_f16 + (nb13 * b_a / type_size_src1);
+                float *dst_shifted = dst_ddf + (nb3 * b_a / sizeof(float));
+                int64_t batches0 = ne02;
+                int64_t batches1 = ne12;
+                launch_gemm_for_batches(src0_f16_shifted, src1_f16_shifted, dst_shifted,
+                    ne00, ne01, batches0, ne10, ne11, batches1, str_a0, str_a1,
+                    str_a2, str_b0, str_b1, str_b2, nb2 / sizeof(float));
             }
         }
-
+
     }
     else
 #endif
     {
-        if (r2 == 1 && r3 == 1 &&
+        if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
+            // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
+            const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
+            const int64_t smb = ne12 == 1 ? s13 : s12;
+
             // there is no broadcast and src0, src1 are contiguous across dims 2, 3
             SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
                 oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
-                src0_f16, dpct::library_data_t::real_half, nb01 / nb00,
-                src1_f16, dpct::library_data_t::real_half, s11,
+                src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma,
+                src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf,
                 mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
         } else {
             const int ne23 = ne12 * ne13;
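When the two batch counts differ and neither is 1, the launch_gemm_for_batches lambda above cannot hand the broadcast to oneDNN directly, so it walks the smaller batch and issues one grouped GEMM per sub_batch of the larger side; this is the typical grouped-query-attention shape. A small worked example of that arithmetic, assuming the larger batch divides evenly by the smaller one (which the integer division requires):

    // Illustrative: 32 batches on side 0 against 8 on side 1 -> 8 GEMMs of sub_batch 4.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t batches0     = 32;                      // larger side
        const int64_t batches1     = 8;                       // smaller side
        const int64_t num_mul_mats = batches1;                // iterate over the smaller set
        const int64_t sub_batch    = batches0 / num_mul_mats; // 4 side-0 matrices per call

        for (int64_t i = 0; i < num_mul_mats; i++) {
            // each call multiplies sub_batch consecutive side-0 matrices
            // against the single side-1 matrix i
            std::printf("gemm %lld: src0 batches [%lld, %lld) x src1 batch %lld\n",
                        (long long) i, (long long) (i * sub_batch),
                        (long long) ((i + 1) * sub_batch), (long long) i);
        }
        return 0;
    }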
@@ -2945,7 +2900,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
                 void ** ptrs_dst_get = ptrs_dst.get();
                 size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
                 size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
-
+                cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
                     k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
                                            nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
                 });
@@ -3260,26 +3215,27 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
         // The kernel from the if path is faster for that specific case, but does not support all mul mats.
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     }
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) &&
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1 && src1->ne[3] == 1) {
         // KQV single-batch
         ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) {
         // KQ + KQV multi-batch
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
-        constexpr bool convert_src1_to_q8_1 = false;
         opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec
+        ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec);
     } else if (use_mul_mat_vec_q) {
-        constexpr bool convert_src1_to_q8_1 = true;
         opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
-
+        ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
+        if (extra && extra->optimized_feature.reorder) {
+            ggml_sycl_op_mul_mat<quantize_and_reorder_q8_1_soa>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
+        } else {
+            ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
+        }
     } else if (use_mul_mat_q) {
-
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
+        ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q);
     } else {
-
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
+        ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl);
     }
 }
 
@@ -3446,10 +3402,13 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
         SYCL_CHECK(CHECK_TRY_ERROR(
             stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
 
+        const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
+        assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
+
         {
-            sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10,
+            sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size));
             sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
-
+            stream->submit([&](sycl::handler &cgh) {
                 sycl::local_accessor<int, 0> src1_row_acc(cgh);
 
                 char *__restrict src1_contiguous_get =
@@ -3461,8 +3420,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
                 size_t ids_nb_ct6 = ids->nb[1];
                 size_t ids_nb_ct7 = ids->nb[0];
 
-
-
+                cgh.parallel_for(
+                    sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         k_copy_src1_to_contiguous(
                             src1_original, src1_contiguous_get,
                             dev_cur_src1_row_get,
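Both copy kernels in ggml_sycl_mul_mat_id now clamp their work-group size to the device limit queried once up front, rather than assuming ne10 or ne0 fits. The clamp in isolation, sketched against the standard SYCL device query (the diff itself reads the limit from ggml_sycl_info(); the helper below is illustrative):

    // Illustrative: pick a 1-D work-group size no larger than the device allows.
    #include <sycl/sycl.hpp>
    #include <algorithm>

    static sycl::range<3> clamped_block_dims(const sycl::queue & q, size_t want) {
        const size_t max_wg = q.get_device().get_info<sycl::info::device::max_work_group_size>();
        const size_t n      = std::min(want, max_wg);  // never exceed the device limit
        return sycl::range<3>(1, 1, n);
    }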
@@ -3491,16 +3451,17 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
             ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
 
             {
-                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0,
+                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size));
                 sycl::range<3> grid_dims(1, 1, num_src1_rows);
-
+                stream->submit([&](sycl::handler &cgh) {
                     const char *__restrict dst_contiguous_get =
                         dst_contiguous.get();
                     const mmid_row_mapping *__restrict dev_row_mapping_get =
                         dev_row_mapping.get();
 
-
-
+                    cgh.parallel_for(
+                        sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                        [=](sycl::nd_item<3> item_ct1) {
                             k_copy_dst_from_contiguous(dst_original,
                                                        dst_contiguous_get,
                                                        dev_row_mapping_get,
@@ -3603,6 +3564,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_GET_ROWS:
             ggml_sycl_get_rows(ctx, dst);
             break;
+        case GGML_OP_SET_ROWS:
+            ggml_sycl_op_set_rows(ctx, dst);
+            break;
         case GGML_OP_DUP:
             ggml_sycl_dup(ctx, dst);
             break;
@@ -3613,6 +3577,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_SUB:
             ggml_sycl_sub(ctx, dst);
            break;
+        case GGML_OP_COUNT_EQUAL:
+            ggml_sycl_count_equal(ctx, dst);
+            break;
         case GGML_OP_ACC:
             ggml_sycl_acc(ctx, dst);
             break;
@@ -3687,6 +3654,12 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
             case GGML_GLU_OP_SWIGLU:
                 ggml_sycl_swiglu(ctx, dst);
                 break;
+            case GGML_GLU_OP_GEGLU_ERF:
+                ggml_sycl_geglu_erf(ctx, dst);
+                break;
+            case GGML_GLU_OP_GEGLU_QUICK:
+                ggml_sycl_geglu_quick(ctx, dst);
+                break;
             default:
                 return false;
         }
@@ -4100,6 +4073,7 @@ static ggml_backend_i ggml_backend_sycl_interface = {
     /* .graph_compute = */ ggml_backend_sycl_graph_compute,
     /* .event_record = */ ggml_backend_sycl_event_record,
     /* .event_wait = */ ggml_backend_sycl_event_wait,
+    /* .graph_optimize = */ NULL,
 };
 
 static ggml_guid_t ggml_backend_sycl_guid() {
@@ -4232,6 +4206,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous_1(op->src[0]);
                 default:
                     return false;
@@ -4240,15 +4216,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             {
-                struct ggml_tensor * a;
-                struct ggml_tensor * b;
-
-                    a = op->src[0];
-                    b = op->src[1];
-                } else {
-                    a = op->src[2];
-                    b = op->src[1];
-                }
+                struct ggml_tensor * a = op->src[0];
+                struct ggml_tensor * b = op->src[1];
+
                 if (a->ne[3] != b->ne[3]) {
                     return false;
                 }
@@ -4263,7 +4233,18 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                     }
                 }
                 ggml_type src0_type = op->src[0]->type;
-                if (src0_type == GGML_TYPE_BF16) {
+                if (src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_MXFP4) {
+                    // TODO: support MXFP4
+                    // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
+                    return false;
+                }
+                // TODO: The configuration below needs more work to be supported with oneDNN
+                if (ggml_is_permuted(a) && !ggml_is_contiguous(a) && a->ne[2] > 1 && a->ne[3] > 1) {
+                    return false;
+                }
+                // TODO: This specific configuration can fail with oneDNN and needs more debugging
+                if (!ggml_is_permuted(a) && ggml_is_permuted(b) && b->ne[2] > 1 && b->ne[3] > 1 &&
+                    a->ne[0] > 128 && a->ne[2] == 1 && src0_type == GGML_TYPE_F16) {
                     return false;
                 }
                 return true;
@@ -4285,6 +4266,14 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                     return false;
                 }
             }
+        case GGML_OP_SET_ROWS:
+            {
+                return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
+                         op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q5_0 ||
+                         op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_IQ4_NL) &&
+                        (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32));
+            }
+            break;
         case GGML_OP_CPY:
             {
                 ggml_type src0_type = op->src[0]->type;
@@ -4370,6 +4359,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_ADD:
         case GGML_OP_ADD1:
         case GGML_OP_SUB:
+        case GGML_OP_COUNT_EQUAL:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_REPEAT:
@@ -4386,29 +4376,44 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
 #endif
         case GGML_OP_NORM:
-        case GGML_OP_RMS_NORM:
             return true;
         case GGML_OP_L2_NORM:
         case GGML_OP_GROUP_NORM:
             return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_RMS_NORM:
+            return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
         case GGML_OP_SCALE:
             return true;
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
-        case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
-
+            // TODO: support batching
+            if (op->src[0]->ne[3] != 1) {
+                return false;
+            }
+            // TODO: support attention sinks [TAG_ATTN_SINKS]
+            if (op->src[2]) {
+                return false;
+            }
+            // TODO: support broadcast
+            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
+            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
+        case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_ROPE:
         case GGML_OP_IM2COL:
             return true;
         case GGML_OP_UPSCALE:
             return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
-        case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
+            return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_POOL_2D:
         case GGML_OP_ACC:
+            return true;
         case GGML_OP_PAD:
+            return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
+                   (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_RWKV_WKV6:
@@ -4619,10 +4624,10 @@ ggml_backend_t ggml_backend_sycl_init(int device) {
     };
 
     ggml_backend_t sycl_backend = new ggml_backend {
-        /* .guid
-        /* .
-        /* .device
-        /* .context
+        /* .guid = */ ggml_backend_sycl_guid(),
+        /* .iface = */ ggml_backend_sycl_interface,
+        /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device),
+        /* .context = */ ctx
     };
 
     return sycl_backend;