whispercpp 1.3.2 → 1.3.4
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +6 -3
- data/README.md +71 -14
- data/Rakefile +20 -7
- data/ext/.gitignore +4 -6
- data/ext/dependencies.rb +36 -24
- data/ext/extconf.rb +1 -1
- data/ext/options.rb +48 -184
- data/ext/ruby_whisper.c +18 -0
- data/ext/ruby_whisper_context.c +43 -12
- data/ext/ruby_whisper_model.c +1 -1
- data/ext/ruby_whisper_params.c +59 -27
- data/ext/ruby_whisper_segment.c +81 -4
- data/ext/ruby_whisper_transcribe.cpp +13 -7
- data/ext/ruby_whisper_vad_params.c +1 -1
- data/ext/sources/CMakeLists.txt +5 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/build-xcframework.sh +24 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
- data/ext/sources/examples/addon.node/addon.cpp +154 -35
- data/ext/sources/examples/addon.node/index.js +10 -5
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/bench.cpp +29 -18
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +7 -4
- data/ext/sources/examples/command/command.cpp +58 -32
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/common-whisper.cpp +14 -7
- data/ext/sources/examples/lsp/lsp.cpp +21 -17
- data/ext/sources/examples/quantize/quantize.cpp +3 -0
- data/ext/sources/examples/server/CMakeLists.txt +3 -0
- data/ext/sources/examples/server/server.cpp +193 -35
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +10 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -0
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
- data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +756 -15
- data/ext/sources/examples/talk-llama/llama-arch.h +85 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +773 -272
- data/ext/sources/examples/talk-llama/llama-batch.h +126 -55
- data/ext/sources/examples/talk-llama/llama-chat.cpp +150 -13
- data/ext/sources/examples/talk-llama/llama-chat.h +8 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +814 -542
- data/ext/sources/examples/talk-llama/llama-context.h +68 -32
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
- data/ext/sources/examples/talk-llama/llama-graph.cpp +787 -440
- data/ext/sources/examples/talk-llama/llama-graph.h +333 -153
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +128 -6
- data/ext/sources/examples/talk-llama/llama-hparams.h +80 -17
- data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +137 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +1248 -1967
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +218 -345
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +164 -52
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +266 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +139 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1154 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +182 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +94 -4
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +44 -17
- data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +11377 -5248
- data/ext/sources/examples/talk-llama/llama-model.h +87 -9
- data/ext/sources/examples/talk-llama/llama-quant.cpp +137 -16
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +502 -38
- data/ext/sources/examples/talk-llama/llama-vocab.h +46 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -17
- data/ext/sources/examples/talk-llama/llama.h +176 -151
- data/ext/sources/examples/talk-llama/talk-llama.cpp +11 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +212 -0
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +6 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
- data/ext/sources/ggml/CMakeLists.txt +106 -33
- data/ext/sources/ggml/cmake/common.cmake +24 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-backend.h +18 -2
- data/ext/sources/ggml/include/ggml-cpu.h +2 -0
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml.h +365 -21
- data/ext/sources/ggml/src/CMakeLists.txt +98 -25
- data/ext/sources/ggml/src/ggml-alloc.c +265 -141
- data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +35 -13
- data/ext/sources/ggml/src/ggml-backend.cpp +266 -60
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +4 -4
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +15 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
- data/ext/sources/ggml/src/ggml-cann/common.h +149 -2
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +521 -78
- data/ext/sources/ggml/src/ggml-common.h +21 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +165 -50
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -3
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +3650 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1891 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2160 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1897 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +214 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +18 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +23 -7
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +179 -110
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +44 -33
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +228 -98
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +532 -1124
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +3374 -2081
- data/ext/sources/ggml/src/ggml-cpu/ops.h +13 -8
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1193 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +34 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1982 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +120 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +367 -46
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +3 -3
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +272 -35
- data/ext/sources/ggml/src/ggml-cpu/vec.h +794 -142
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +291 -81
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +117 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +20 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +499 -368
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +142 -93
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +90 -50
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +636 -222
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +73 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +506 -0
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +4 -5
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -98
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +23 -19
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +179 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -2
- data/ext/sources/ggml/src/ggml-impl.h +229 -175
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +21 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +163 -63
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3208 -1575
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +32 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4430 -792
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +138 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +117 -24
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +85 -62
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +13 -17
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +21 -2
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +116 -211
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +700 -1041
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +20 -9
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +17 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +2 -96
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +393 -250
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +32 -8
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -11
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +125 -21
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +4 -3
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +105 -17
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4198 -1145
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +64 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +338 -71
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml.c +802 -142
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +32 -4
- data/ext/sources/include/whisper.h +2 -0
- data/ext/sources/src/CMakeLists.txt +2 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
- data/ext/sources/src/whisper.cpp +241 -215
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +3 -3
- data/ext/sources/tests/test-vad.cpp +2 -2
- data/extsources.rb +15 -9
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +57 -2
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +75 -38
- data/{tests → test}/helper.rb +1 -12
- data/{tests → test}/test_model.rb +9 -0
- data/test/test_package.rb +51 -0
- data/{tests → test}/test_params.rb +8 -0
- data/test/test_segment.rb +146 -0
- data/{tests → test}/test_whisper.rb +70 -0
- data/whispercpp.gemspec +2 -3
- metadata +246 -191
- data/ext/sources/.dockerignore +0 -3
- data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
- data/ext/sources/ci/run.sh +0 -336
- data/ext/sources/close-issue.yml +0 -28
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -6431
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -336
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -5998
- data/tests/test_package.rb +0 -46
- data/tests/test_segment.rb +0 -74
- /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
- /data/{tests → test}/test_callback.rb +0 -0
- /data/{tests → test}/test_error.rb +0 -0
- /data/{tests → test}/test_vad.rb +0 -0
- /data/{tests → test}/test_vad_params.rb +0 -0
@@ -26,7 +26,7 @@ static void im2col_kernel(const float * x, T * dst, int64_t batch_offset, int64_
|
|
26
26
|
|
27
27
|
// make each work-item deal with more elements since sycl global range can not exceed max int
|
28
28
|
for (int64_t i = global_id; i < pelements; i += (work_group_size * item_ct1.get_group_range(2))) {
|
29
|
-
const int64_t ksize = OW *
|
29
|
+
const int64_t ksize = OW * KH;
|
30
30
|
const int64_t kx = i / ksize;
|
31
31
|
const int64_t kd = kx * ksize;
|
32
32
|
const int64_t ky = (i - kd) / OW;
|
@@ -29,24 +29,23 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
|
|
29
29
|
static_assert(blocks_per_subgroup > 0);
|
30
30
|
static_assert(block_elements_per_subgroup > 0);
|
31
31
|
|
32
|
-
const block_q8_1 * y = (const block_q8_1 *) vy;
|
33
|
-
|
34
32
|
float partial_sum = 0.0f;
|
35
33
|
for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
|
36
|
-
const int ibx
|
37
|
-
// TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits
|
38
|
-
const int bx_offset = block_type::get_block_offset(ibx);
|
39
|
-
const int d_offset = block_type::get_d_offset(nrows, ncols, ibx);
|
34
|
+
const int ibx = row * blocks_per_row + i; // x block index
|
40
35
|
|
36
|
+
const auto bx_offset = block_type::get_block_offset(ibx, nblocks);
|
37
|
+
const auto d_offset = block_type::get_d_offset(nrows, ncols, ibx);
|
41
38
|
// Y block index that aligns with ibx
|
42
39
|
const int iby = i * block_type::block_to_q8_1_ratio();
|
40
|
+
const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1;
|
41
|
+
const sycl::half2* q8_1_ds_ptr = (const sycl::half2*)((const char*)vy + ncols + iby * sizeof(sycl::half2));
|
43
42
|
|
44
43
|
#pragma unroll
|
45
44
|
for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) {
|
46
45
|
// x block quant index when casting the quants to int
|
47
46
|
const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
|
48
47
|
|
49
|
-
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset,
|
48
|
+
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs);
|
50
49
|
}
|
51
50
|
}
|
52
51
|
|
@@ -785,6 +784,24 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
|
785
784
|
}
|
786
785
|
}
|
787
786
|
|
787
|
+
static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
|
788
|
+
const int nrows, dpct::queue_ptr stream) {
|
789
|
+
GGML_ASSERT(ncols % QK_K == 0);
|
790
|
+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
|
791
|
+
constexpr size_t num_subgroups = 16;
|
792
|
+
GGML_ASSERT(block_num_y % num_subgroups == 0);
|
793
|
+
|
794
|
+
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
795
|
+
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
796
|
+
|
797
|
+
stream->submit([&](sycl::handler & cgh) {
|
798
|
+
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
799
|
+
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
800
|
+
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows,
|
801
|
+
nd_item);
|
802
|
+
});
|
803
|
+
});
|
804
|
+
}
|
788
805
|
static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
789
806
|
float *dst, const int ncols,
|
790
807
|
const int nrows,
|
@@ -1070,7 +1087,14 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
|
|
1070
1087
|
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
1071
1088
|
break;
|
1072
1089
|
case GGML_TYPE_Q6_K:
|
1073
|
-
|
1090
|
+
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
1091
|
+
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
1092
|
+
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n");
|
1093
|
+
reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
1094
|
+
} else {
|
1095
|
+
GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_k_q8_1_sycl\n");
|
1096
|
+
mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
1097
|
+
}
|
1074
1098
|
break;
|
1075
1099
|
case GGML_TYPE_IQ1_S:
|
1076
1100
|
mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
@@ -0,0 +1,133 @@
|
|
1
|
+
/***************************************************************************
|
2
|
+
*
|
3
|
+
* Copyright (C) 2025 Codeplay Software Ltd.
|
4
|
+
* Copyright (C) 2025 Intel Corporation
|
5
|
+
*
|
6
|
+
* MIT License
|
7
|
+
*
|
8
|
+
* Unless required by applicable law or agreed to in writing, software
|
9
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
* See the License for the specific language governing permissions and
|
12
|
+
* limitations under the License.
|
13
|
+
*
|
14
|
+
* quantize.hpp
|
15
|
+
*
|
16
|
+
* Description:
|
17
|
+
* Sycl backend specific quantization functions
|
18
|
+
**************************************************************************/
|
19
|
+
|
20
|
+
#pragma once
|
21
|
+
|
22
|
+
#include <sycl/nd_item.hpp>
|
23
|
+
|
24
|
+
#include "ggml-sycl/dpct/helper.hpp"
|
25
|
+
|
26
|
+
template <int ElementsPerWI>
|
27
|
+
__dpct_inline__ static void quantize_q8_1_impl(const float * __restrict__ x,
|
28
|
+
sycl::vec<int8_t, ElementsPerWI> & quantized_values, float & d,
|
29
|
+
float & sum, const sycl::nd_item<1> & it) {
|
30
|
+
auto subgroup_id = it.get_group(0);
|
31
|
+
auto wi_id = it.get_local_id(0);
|
32
|
+
|
33
|
+
sycl::vec<float, ElementsPerWI> wi_f32_vals;
|
34
|
+
|
35
|
+
auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
|
36
|
+
wi_f32_vals = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset);
|
37
|
+
|
38
|
+
float amax = 0.0f;
|
39
|
+
|
40
|
+
#pragma unroll(ElementsPerWI)
|
41
|
+
for (int i = 0; i < ElementsPerWI; i++) {
|
42
|
+
sum += wi_f32_vals[i];
|
43
|
+
amax = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i]));
|
44
|
+
quantized_values[i] = 0;
|
45
|
+
}
|
46
|
+
sum = sycl::reduce_over_group(it.get_sub_group(), sum, sycl::plus<float>());
|
47
|
+
amax = sycl::reduce_over_group(it.get_sub_group(), amax, sycl::maximum<float>());
|
48
|
+
d = amax == 0 ? 1 : amax / 127;
|
49
|
+
|
50
|
+
#pragma unroll(ElementsPerWI)
|
51
|
+
for (int i = 0; i < ElementsPerWI; i++) {
|
52
|
+
quantized_values[i] = sycl::round(wi_f32_vals[i] / d);
|
53
|
+
}
|
54
|
+
|
55
|
+
d = amax == 0 ? 0 : d;
|
56
|
+
}
|
57
|
+
|
58
|
+
// No op to control codepath in ggml_sycl_op_mul_mat
|
59
|
+
template <int ElementsPerWI> struct no_quantize_q8_1 {
|
60
|
+
void operator()(const float *, void *, int, int, const sycl::nd_item<1> &) const {}
|
61
|
+
};
|
62
|
+
|
63
|
+
template <int ElementsPerWI> struct quantize_and_reorder_q8_1_soa {
|
64
|
+
__dpct_inline__ void operator()(const float * __restrict__ x, void * reordered_q8_tensor, const int kx,
|
65
|
+
const int kx_padded, const sycl::nd_item<1> & it) const {
|
66
|
+
/*
|
67
|
+
Quantizes and reorders the resultant q8 tensor in a per row fashion
|
68
|
+
Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
|
69
|
+
*/
|
70
|
+
auto subgroup_id = it.get_group(0);
|
71
|
+
auto wi_id = it.get_local_id(0);
|
72
|
+
|
73
|
+
sycl::vec<int8_t, ElementsPerWI> quantized_values;
|
74
|
+
float d = 0.0f;
|
75
|
+
float sum = 0.0f;
|
76
|
+
quantize_q8_1_impl<ElementsPerWI>(x, quantized_values, d, sum, it);
|
77
|
+
|
78
|
+
const int num_blocks_per_row = kx / QK8_1;
|
79
|
+
auto row = subgroup_id / num_blocks_per_row;
|
80
|
+
auto col = subgroup_id % num_blocks_per_row;
|
81
|
+
auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1);
|
82
|
+
auto col_offset = QK8_1 * col + wi_id * ElementsPerWI;
|
83
|
+
|
84
|
+
auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
|
85
|
+
*reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values;
|
86
|
+
|
87
|
+
auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2));
|
88
|
+
if (wi_id == 0) {
|
89
|
+
*ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum));
|
90
|
+
}
|
91
|
+
}
|
92
|
+
};
|
93
|
+
|
94
|
+
template <int ElementsPerWI> struct quantize_q8_1 {
|
95
|
+
__dpct_inline__ void operator()(const float * __restrict__ x, void * q8_tensor, const int kx, const int kx_padded,
|
96
|
+
const sycl::nd_item<1> & it) const {
|
97
|
+
auto subgroup_id = it.get_group(0);
|
98
|
+
auto wi_id = it.get_local_id(0);
|
99
|
+
|
100
|
+
const int num_blocks_per_row = kx / QK8_1;
|
101
|
+
auto row = subgroup_id / num_blocks_per_row;
|
102
|
+
const int pitch = kx_padded / QK8_1;
|
103
|
+
|
104
|
+
sycl::vec<int8_t, ElementsPerWI> quantized_values;
|
105
|
+
float d = 0.0f;
|
106
|
+
float sum = 0.0f;
|
107
|
+
quantize_q8_1_impl<ElementsPerWI>(x, quantized_values, d, sum, it);
|
108
|
+
|
109
|
+
block_q8_1 * quant_ptr = (block_q8_1 *) q8_tensor;
|
110
|
+
auto block_id = subgroup_id % num_blocks_per_row + row * pitch;
|
111
|
+
|
112
|
+
int8_t * qs = &(quant_ptr[block_id].qs[wi_id * ElementsPerWI]);
|
113
|
+
*reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(qs) = quantized_values;
|
114
|
+
if (wi_id == 0) {
|
115
|
+
quant_ptr[block_id].ds = sycl::half2(sycl::half(d), sycl::half(sum));
|
116
|
+
}
|
117
|
+
}
|
118
|
+
};
|
119
|
+
|
120
|
+
template <template <int> typename quantize_f>
|
121
|
+
void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded,
|
122
|
+
dpct::queue_ptr stream) {
|
123
|
+
static_assert(QK8_1 % WARP_SIZE == 0);
|
124
|
+
auto local_range = std::size_t(WARP_SIZE);
|
125
|
+
auto num_quant_blocks = ky * (kx / QK8_1);
|
126
|
+
auto global_range = num_quant_blocks * local_range;
|
127
|
+
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
128
|
+
|
129
|
+
stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }),
|
130
|
+
[=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
131
|
+
quantize_f<QK8_1 / WARP_SIZE>()(x, vy, kx, kx_padded, it);
|
132
|
+
});
|
133
|
+
}
|
@@ -14,12 +14,13 @@
|
|
14
14
|
#ifndef GGML_SYCL_QUANTS_HPP
|
15
15
|
#define GGML_SYCL_QUANTS_HPP
|
16
16
|
|
17
|
+
#include <utility>
|
18
|
+
|
17
19
|
#include "ggml-common.h"
|
18
20
|
#include "ggml.h"
|
19
21
|
|
20
22
|
namespace ggml_sycl_reordered {
|
21
23
|
|
22
|
-
|
23
24
|
// The reordered block moves quants (qs) and scales(d) to two
|
24
25
|
// uniform regions of memory that is contiguous in the same tensor.
|
25
26
|
// What this means is that instead of having:
|
@@ -32,7 +33,6 @@ namespace ggml_sycl_reordered {
|
|
32
33
|
|
33
34
|
template <ggml_type type> struct block_q_t;
|
34
35
|
|
35
|
-
|
36
36
|
// qk number of weights / quants in a block
|
37
37
|
// qr number of weights in a byte (described as 'before dequantization')
|
38
38
|
// for quantization types that has low and high bits split, qr is calculated with
|
@@ -47,10 +47,12 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> {
|
|
47
47
|
static constexpr uint32_t vdr_mmvq = 2;
|
48
48
|
};
|
49
49
|
|
50
|
-
static constexpr int get_block_offset(const int block_index
|
50
|
+
static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
|
51
|
+
return { block_index * (QK4_0 / QR4_0), 0 };
|
52
|
+
}
|
51
53
|
|
52
|
-
static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
|
53
|
-
return (ncols /
|
54
|
+
static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
|
55
|
+
return { (ncols / QR4_0 * nrows) + block_index * sizeof(ggml_half), 0 };
|
54
56
|
}
|
55
57
|
|
56
58
|
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
@@ -64,18 +66,43 @@ template <> struct block_q_t<GGML_TYPE_Q4_K> {
|
|
64
66
|
static constexpr uint32_t vdr_mmvq = 2;
|
65
67
|
};
|
66
68
|
|
67
|
-
static constexpr int get_block_offset(const int block_index
|
69
|
+
static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
|
70
|
+
return { block_index * (traits::qk / traits::qr), 0 };
|
71
|
+
}
|
68
72
|
|
69
|
-
static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
|
70
|
-
auto nblocks = (nrows * (ncols /
|
71
|
-
return
|
73
|
+
static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
|
74
|
+
auto nblocks = (nrows * (ncols / QK_K));
|
75
|
+
return { nblocks * (QK_K / 2) + (block_index * K_SCALE_SIZE),
|
76
|
+
(nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)) };
|
72
77
|
}
|
73
78
|
|
74
79
|
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
80
|
+
};
|
81
|
+
|
82
|
+
template <> struct block_q_t<GGML_TYPE_Q6_K> {
|
83
|
+
struct traits {
|
84
|
+
static constexpr uint32_t qk = QK_K;
|
85
|
+
static constexpr uint32_t qi = QI6_K;
|
86
|
+
static constexpr uint32_t qr = QR6_K;
|
87
|
+
static constexpr uint32_t vdr_mmvq = 1;
|
88
|
+
};
|
75
89
|
|
76
|
-
constexpr
|
90
|
+
static constexpr std::pair<int, int> get_block_offset(const int block_index, const int n_blocks) {
|
91
|
+
auto low_bits_index = block_index * (QK_K / QR6_K);
|
92
|
+
// the index of high bits it's after all low bits
|
93
|
+
auto high_bits_index = n_blocks * (QK_K / 2) + (block_index * (QK_K / 4));
|
94
|
+
return { low_bits_index, high_bits_index };
|
95
|
+
}
|
77
96
|
|
78
|
-
constexpr
|
97
|
+
static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
|
98
|
+
auto nblocks = (nrows * (ncols / QK_K));
|
99
|
+
auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 4);
|
100
|
+
auto block_scales = total_qs_bytes + block_index * (QK_K / 16);
|
101
|
+
auto sb_scale = total_qs_bytes + nblocks * (QK_K / 16) + block_index * sizeof(ggml_half);
|
102
|
+
return { block_scales, sb_scale };
|
103
|
+
}
|
104
|
+
|
105
|
+
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
79
106
|
};
|
80
107
|
|
81
108
|
} // namespace ggml_sycl_reordered
|
@@ -47,21 +47,17 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const
|
|
47
47
|
|
48
48
|
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
|
49
49
|
|
50
|
-
if (i0 >= n_dims) {
|
51
|
-
const int i = row * ne0 + i0;
|
52
|
-
|
53
|
-
dst[i + 0] = x[i + 0];
|
54
|
-
dst[i + 1] = x[i + 1];
|
55
|
-
|
56
|
-
return;
|
57
|
-
}
|
58
|
-
|
59
50
|
const int row0 = row % ne1;
|
60
51
|
const int channel0 = row / ne1;
|
61
52
|
|
62
53
|
const int i = row * ne0 + i0;
|
63
54
|
const int i2 = channel0 * s2 + row0 * s1 + i0;
|
64
55
|
|
56
|
+
if (i0 >= n_dims) {
|
57
|
+
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
|
58
|
+
return;
|
59
|
+
}
|
60
|
+
|
65
61
|
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
|
66
62
|
|
67
63
|
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
@@ -91,21 +87,17 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
|
|
91
87
|
|
92
88
|
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
|
93
89
|
|
94
|
-
if (i0 >= n_dims) {
|
95
|
-
const int i = row * ne0 + i0;
|
96
|
-
|
97
|
-
dst[i + 0] = x[i + 0];
|
98
|
-
dst[i + 1] = x[i + 1];
|
99
|
-
|
100
|
-
return;
|
101
|
-
}
|
102
|
-
|
103
90
|
const int row0 = row % ne1;
|
104
91
|
const int channel0 = row / ne1;
|
105
92
|
|
106
93
|
const int i = row * ne0 + i0 / 2;
|
107
94
|
const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
|
108
95
|
|
96
|
+
if (i0 >= n_dims) {
|
97
|
+
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
|
98
|
+
return;
|
99
|
+
}
|
100
|
+
|
109
101
|
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
|
110
102
|
|
111
103
|
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
@@ -122,6 +114,62 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
|
|
122
114
|
dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
|
123
115
|
}
|
124
116
|
|
117
|
+
template <typename T, bool has_ff>
|
118
|
+
static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
119
|
+
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
|
120
|
+
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
|
121
|
+
const float theta_scale, const float * freq_factors, const mrope_sections sections,
|
122
|
+
const sycl::nd_item<3> & item_ct1) {
|
123
|
+
// get index pos
|
124
|
+
const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
|
125
|
+
if (i0 >= ne0) {
|
126
|
+
return;
|
127
|
+
}
|
128
|
+
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
|
129
|
+
|
130
|
+
const int row_x = row_dst % ne1;
|
131
|
+
const int channel_x = row_dst / ne1;
|
132
|
+
const int idst = (row_dst * ne0) + (i0 / 2);
|
133
|
+
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
|
134
|
+
|
135
|
+
if (i0 >= n_dims) {
|
136
|
+
*reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
|
137
|
+
return;
|
138
|
+
}
|
139
|
+
|
140
|
+
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
|
141
|
+
const int sec_w = sections.v[1] + sections.v[0];
|
142
|
+
const int sector = (i0 / 2) % sect_dims;
|
143
|
+
|
144
|
+
|
145
|
+
float theta_base = 0.0;
|
146
|
+
if (sector < sections.v[0]) {
|
147
|
+
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
|
148
|
+
}
|
149
|
+
else if (sector >= sections.v[0] && sector < sec_w) {
|
150
|
+
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
|
151
|
+
}
|
152
|
+
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
153
|
+
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
|
154
|
+
}
|
155
|
+
else if (sector >= sec_w + sections.v[2]) {
|
156
|
+
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
|
157
|
+
}
|
158
|
+
|
159
|
+
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
160
|
+
float cos_theta;
|
161
|
+
float sin_theta;
|
162
|
+
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
163
|
+
const float x0 = x[ix + 0];
|
164
|
+
const float x1 = x[ix + n_dims/2];
|
165
|
+
|
166
|
+
// store results in dst
|
167
|
+
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
|
168
|
+
dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta;
|
169
|
+
}
|
170
|
+
|
171
|
+
|
172
|
+
|
125
173
|
template <typename T, bool has_ff>
|
126
174
|
static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
127
175
|
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
|
@@ -171,7 +219,7 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c
|
|
171
219
|
const float * freq_factors, queue_ptr stream) {
|
172
220
|
GGML_ASSERT(ne0 % 2 == 0);
|
173
221
|
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
174
|
-
const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
|
222
|
+
const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
175
223
|
const sycl::range<3> block_nums(1, num_blocks_x, nr);
|
176
224
|
|
177
225
|
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
@@ -208,7 +256,7 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
|
|
208
256
|
const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
|
209
257
|
GGML_ASSERT(ne0 % 2 == 0);
|
210
258
|
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
211
|
-
const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
|
259
|
+
const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
212
260
|
const sycl::range<3> block_nums(1, num_blocks_x, nr);
|
213
261
|
|
214
262
|
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
@@ -228,6 +276,40 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
|
|
228
276
|
}
|
229
277
|
}
|
230
278
|
|
279
|
+
template <typename T>
|
280
|
+
static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
281
|
+
const size_t s2, const int n_dims, const int nr, const int32_t * pos,
|
282
|
+
const float freq_scale, const float freq_base, const float ext_factor,
|
283
|
+
const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
|
284
|
+
const mrope_sections sections, queue_ptr stream) {
|
285
|
+
GGML_ASSERT(ne0 % 2 == 0);
|
286
|
+
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
287
|
+
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
288
|
+
const sycl::range<3> grid_dims(1, n_blocks_y, nr);
|
289
|
+
const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
|
290
|
+
|
291
|
+
const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
|
292
|
+
// Add FP16 capability check if T could be sycl::half
|
293
|
+
if constexpr (std::is_same_v<T, sycl::half>) {
|
294
|
+
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
295
|
+
}
|
296
|
+
// launch kernel
|
297
|
+
if (freq_factors == nullptr) {
|
298
|
+
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
299
|
+
rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
300
|
+
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
301
|
+
});
|
302
|
+
} else {
|
303
|
+
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
304
|
+
rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
305
|
+
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
306
|
+
});
|
307
|
+
}
|
308
|
+
}
|
309
|
+
|
310
|
+
|
311
|
+
|
312
|
+
|
231
313
|
// rope vision
|
232
314
|
template <typename T>
|
233
315
|
static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
@@ -237,7 +319,7 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
|
|
237
319
|
const mrope_sections sections, queue_ptr stream) {
|
238
320
|
GGML_ASSERT(ne0 % 2 == 0);
|
239
321
|
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
240
|
-
const int n_blocks_y = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
|
322
|
+
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
241
323
|
const sycl::range<3> grid_dims(1, n_blocks_y, nr);
|
242
324
|
const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
|
243
325
|
|
@@ -298,8 +380,17 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
|
298
380
|
memcpy(&sections.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
299
381
|
|
300
382
|
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
383
|
+
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
301
384
|
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
302
385
|
|
386
|
+
if (is_mrope) {
|
387
|
+
GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
|
388
|
+
}
|
389
|
+
|
390
|
+
if (is_vision) {
|
391
|
+
GGML_ASSERT(n_dims == ne00/2);
|
392
|
+
}
|
393
|
+
|
303
394
|
const int32_t * pos = (const int32_t *) dst->src[1]->data;
|
304
395
|
|
305
396
|
const float * freq_factors = nullptr;
|
@@ -326,6 +417,19 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
|
326
417
|
} else {
|
327
418
|
GGML_ABORT("fatal error");
|
328
419
|
}
|
420
|
+
} else if (is_mrope && !is_vision) {
|
421
|
+
GGML_SYCL_DEBUG("%s: mrope path\n", __func__);
|
422
|
+
if (dst->src[0]->type == GGML_TYPE_F16) {
|
423
|
+
rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
|
424
|
+
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
425
|
+
freq_factors, sections, main_stream);
|
426
|
+
} else if (dst->src[0]->type == GGML_TYPE_F32) {
|
427
|
+
rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
|
428
|
+
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
|
429
|
+
main_stream);
|
430
|
+
} else {
|
431
|
+
GGML_ABORT("Fatal error: Tensor type unsupported!");
|
432
|
+
}
|
329
433
|
} else if (is_vision) {
|
330
434
|
GGML_SYCL_DEBUG("%s: vision path\n", __func__);
|
331
435
|
if (dst->src[0]->type == GGML_TYPE_F16) {
|