whispercpp 1.3.3 → 1.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/ruby_whisper_params.c +55 -25
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/build-xcframework.sh +24 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/addon.cpp +19 -19
- data/ext/sources/examples/addon.node/index.js +7 -5
- data/ext/sources/examples/bench/bench.cpp +26 -16
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +4 -2
- data/ext/sources/examples/command/command.cpp +26 -24
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/lsp.cpp +19 -17
- data/ext/sources/examples/server/server.cpp +24 -13
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +4 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
- data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
- data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
- data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
- data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
- data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
- data/ext/sources/examples/talk-llama/llama-context.h +44 -29
- data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
- data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
- data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
- data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
- data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
- data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
- data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
- data/ext/sources/examples/talk-llama/llama-model.h +60 -9
- data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
- data/ext/sources/examples/talk-llama/llama.cpp +65 -10
- data/ext/sources/examples/talk-llama/llama.h +95 -177
- data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
- data/ext/sources/ggml/CMakeLists.txt +59 -31
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-backend.h +17 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -1
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml.h +221 -16
- data/ext/sources/ggml/src/CMakeLists.txt +17 -2
- data/ext/sources/ggml/src/ggml-alloc.c +265 -141
- data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
- data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
- data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
- data/ext/sources/ggml/src/ggml-common.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
- data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
- data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
- data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
- data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
- data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
- data/ext/sources/ggml/src/ggml-impl.h +119 -9
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +111 -16
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml.c +478 -98
- data/ext/sources/ggml/src/gguf.cpp +8 -1
- data/ext/sources/src/whisper.cpp +23 -46
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +3 -3
- data/ext/sources/tests/test-vad.cpp +2 -2
- data/lib/whisper/model/uri.rb +1 -1
- data/sig/whisper.rbs +7 -0
- data/test/test_params.rb +8 -0
- data/test/test_whisper.rb +1 -1
- data/whispercpp.gemspec +1 -1
- metadata +164 -157
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
@@ -1,8 +1,20 @@
|
|
1
1
|
#pragma once
|
2
2
|
|
3
3
|
#include "common.cuh"
|
4
|
+
|
4
5
|
#include <cstdint>
|
5
6
|
|
7
|
+
static __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) {
|
8
|
+
const uint8_t * x8 = (const uint8_t *) x;
|
9
|
+
|
10
|
+
int x32 = x8[4*i32 + 0] << 0;
|
11
|
+
x32 |= x8[4*i32 + 1] << 8;
|
12
|
+
x32 |= x8[4*i32 + 2] << 16;
|
13
|
+
x32 |= x8[4*i32 + 3] << 24;
|
14
|
+
|
15
|
+
return x32;
|
16
|
+
}
|
17
|
+
|
6
18
|
static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
|
7
19
|
const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
|
8
20
|
|
@@ -16,6 +28,72 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
|
|
16
28
|
return ((const int *) x)[i32]; // assume at least 4 byte alignment
|
17
29
|
}
|
18
30
|
|
31
|
+
// q4 contains 8 indices with 4 bit each.
|
32
|
+
// This function selects those bytes from table that are at those indices and returns them as int2.
|
33
|
+
// The first int contains the bytes with even indices in q4, the second int contains the bytes with odd indices in q4.
|
34
|
+
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
|
35
|
+
#if defined(GGML_USE_HIP)
|
36
|
+
// Load the 16-byte table into four 32-bit unsigned integers.
|
37
|
+
const uint32_t *values = (const uint32_t *)table;
|
38
|
+
|
39
|
+
const uint32_t q_even = q4;
|
40
|
+
const uint32_t q_odd = (q4 >> 4);
|
41
|
+
|
42
|
+
// Perform lookups in the lower half of the table (indices 0-7).
|
43
|
+
uint32_t v_even_low = __builtin_amdgcn_perm(values[1], values[0], q_even & 0x07070707);
|
44
|
+
uint32_t v_odd_low = __builtin_amdgcn_perm(values[1], values[0], q_odd & 0x07070707);
|
45
|
+
|
46
|
+
// Perform lookups in the upper half of the table (indices 8-15).
|
47
|
+
uint32_t v_even_high = __builtin_amdgcn_perm(values[3], values[2], q_even & 0x07070707);
|
48
|
+
uint32_t v_odd_high = __builtin_amdgcn_perm(values[3], values[2], q_odd & 0x07070707);
|
49
|
+
|
50
|
+
// Select between the low and high results based on the MSB of each index nibble.
|
51
|
+
uint32_t mask_even = 0x03020100 | ((q_even & 0x08080808) >> 1);
|
52
|
+
uint32_t res_x = __builtin_amdgcn_perm(v_even_high, v_even_low, mask_even);
|
53
|
+
uint32_t mask_odd = 0x03020100 | ((q_odd & 0x08080808) >> 1);
|
54
|
+
uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd);
|
55
|
+
|
56
|
+
return make_int2(res_x, res_y);
|
57
|
+
#elif !defined(GGML_USE_MUSA)
|
58
|
+
// CUDA does not have an instruction for selecting bytes with 4 bit indices.
|
59
|
+
// However, __byte_perm is an instruction that selects bytes with 3 bit indices that can be used instead.
|
60
|
+
const uint32_t * table32 = (const uint32_t *) table;
|
61
|
+
|
62
|
+
// __byte_perm selects bytes based on the lower 16 bits in its third argument.
|
63
|
+
// Therefore, do 2 iterations over the 32 bits in q4 with 0 and 16 shift.
|
64
|
+
// To handle the fourth bit, first call __byte_perm both for the low and the high 64 bit of table, using the low 3 bits.
|
65
|
+
// Then, call __byte_perm again to select from the low and high bytes based on the fourth bit.
|
66
|
+
uint32_t tmp[2];
|
67
|
+
const uint32_t low_high_selection_indices = (0x32103210 | ((q4 & 0x88888888) >> 1));
|
68
|
+
#pragma unroll
|
69
|
+
for (uint32_t i = 0; i < 2; ++i) {
|
70
|
+
const uint32_t shift = 16 * i;
|
71
|
+
|
72
|
+
const uint32_t low = __byte_perm(table32[0], table32[1], q4 >> shift);
|
73
|
+
const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift);
|
74
|
+
tmp[i] = __byte_perm(low, high, low_high_selection_indices >> shift);
|
75
|
+
}
|
76
|
+
|
77
|
+
// tmp contains the bytes from table in the same order as the 4 bit indices in q4.
|
78
|
+
// However, for the result we need ints with all even/odd 4 bit indices in q4.
|
79
|
+
// Therefore, 2 more calls to __byte_perm to put the bytes in the correct order.
|
80
|
+
return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531));
|
81
|
+
#else
|
82
|
+
// Generic implementation.
|
83
|
+
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
|
84
|
+
const int8_t * q0_8 = (const int8_t *) &q0_32;
|
85
|
+
const char4 val0_8 = make_char4(
|
86
|
+
table[q0_8[0]], table[q0_8[1]], table[q0_8[2]], table[q0_8[3]]);
|
87
|
+
|
88
|
+
const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
|
89
|
+
const int8_t * q1_8 = (const int8_t *) &q1_32;
|
90
|
+
const char4 val1_8 = make_char4(
|
91
|
+
table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
|
92
|
+
|
93
|
+
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
|
94
|
+
#endif
|
95
|
+
}
|
96
|
+
|
19
97
|
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
|
20
98
|
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
|
21
99
|
|
@@ -61,7 +139,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
|
|
61
139
|
sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
|
62
140
|
}
|
63
141
|
|
64
|
-
#ifdef
|
142
|
+
#ifdef FAST_FP16_AVAILABLE
|
65
143
|
const float2 tmp = __half22float2(__hmul2(dm4, ds8));
|
66
144
|
const float d4d8 = tmp.x;
|
67
145
|
const float m4s8 = tmp.y;
|
@@ -70,7 +148,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
|
|
70
148
|
const float2 ds8f = __half22float2(ds8);
|
71
149
|
const float d4d8 = dm4f.x * ds8f.x;
|
72
150
|
const float m4s8 = dm4f.y * ds8f.y;
|
73
|
-
#endif //
|
151
|
+
#endif // FAST_FP16_AVAILABLE
|
74
152
|
|
75
153
|
// scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
|
76
154
|
return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
|
@@ -132,7 +210,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
|
|
132
210
|
sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
|
133
211
|
}
|
134
212
|
|
135
|
-
#ifdef
|
213
|
+
#ifdef FAST_FP16_AVAILABLE
|
136
214
|
const float2 tmp = __half22float2(__hmul2(dm5, ds8));
|
137
215
|
const float d5d8 = tmp.x;
|
138
216
|
const float m5s8 = tmp.y;
|
@@ -141,7 +219,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
|
|
141
219
|
const float2 ds8f = __half22float2(ds8);
|
142
220
|
const float d5d8 = dm5f.x * ds8f.x;
|
143
221
|
const float m5s8 = dm5f.y * ds8f.y;
|
144
|
-
#endif //
|
222
|
+
#endif // FAST_FP16_AVAILABLE
|
145
223
|
|
146
224
|
// scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
|
147
225
|
return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
|
@@ -175,7 +253,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
|
|
175
253
|
sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
|
176
254
|
}
|
177
255
|
|
178
|
-
#ifdef
|
256
|
+
#ifdef FAST_FP16_AVAILABLE
|
179
257
|
const float2 tmp = __half22float2(__hmul2(dm8, ds8));
|
180
258
|
const float d8d8 = tmp.x;
|
181
259
|
const float m8s8 = tmp.y;
|
@@ -184,7 +262,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
|
|
184
262
|
const float2 ds8f = __half22float2(ds8);
|
185
263
|
const float d8d8 = dm8f.x * ds8f.x;
|
186
264
|
const float m8s8 = dm8f.y * ds8f.y;
|
187
|
-
#endif //
|
265
|
+
#endif // FAST_FP16_AVAILABLE
|
188
266
|
|
189
267
|
// scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
|
190
268
|
return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
|
@@ -211,6 +289,30 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_
|
|
211
289
|
return d8_1*sumf;
|
212
290
|
}
|
213
291
|
|
292
|
+
#define VDR_MXFP4_Q8_1_MMVQ 2
|
293
|
+
#define VDR_MXFP4_Q8_1_MMQ 4
|
294
|
+
|
295
|
+
static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
|
296
|
+
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
297
|
+
|
298
|
+
const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq + kbx;
|
299
|
+
|
300
|
+
const int * q8 = (const int *) bq8_1->qs + iqs;
|
301
|
+
|
302
|
+
int sumi = 0;
|
303
|
+
#pragma unroll
|
304
|
+
for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) {
|
305
|
+
const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
|
306
|
+
const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
|
307
|
+
|
308
|
+
sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
|
309
|
+
sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
|
310
|
+
}
|
311
|
+
|
312
|
+
const float d = ggml_cuda_e8m0_to_fp32(bq4->e) * 0.5f * __low2float(bq8_1->ds);
|
313
|
+
return d * sumi;
|
314
|
+
}
|
315
|
+
|
214
316
|
#define VDR_Q2_K_Q8_1_MMVQ 1
|
215
317
|
#define VDR_Q2_K_Q8_1_MMQ 4
|
216
318
|
|
@@ -1068,20 +1170,6 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
|
|
1068
1170
|
return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
|
1069
1171
|
}
|
1070
1172
|
|
1071
|
-
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
|
1072
|
-
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
|
1073
|
-
const int8_t * q0_8 = (const int8_t *) &q0_32;
|
1074
|
-
const char4 val0_8 = make_char4(
|
1075
|
-
kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
|
1076
|
-
|
1077
|
-
const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
|
1078
|
-
const int8_t * q1_8 = (const int8_t *) &q1_32;
|
1079
|
-
const char4 val1_8 = make_char4(
|
1080
|
-
kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]);
|
1081
|
-
|
1082
|
-
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
|
1083
|
-
}
|
1084
|
-
|
1085
1173
|
#define VDR_IQ4_NL_Q8_1_MMVQ 2
|
1086
1174
|
#define VDR_IQ4_NL_Q8_1_MMQ 4
|
1087
1175
|
|
@@ -1096,7 +1184,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
|
|
1096
1184
|
#pragma unroll
|
1097
1185
|
for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
|
1098
1186
|
const int aux_q4 = get_int_b2(bq4->qs, iqs + l);
|
1099
|
-
const int2 v = get_int_from_table_16(aux_q4);
|
1187
|
+
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
|
1100
1188
|
|
1101
1189
|
sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
|
1102
1190
|
sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
|
@@ -1118,7 +1206,7 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
|
|
1118
1206
|
#pragma unroll
|
1119
1207
|
for (int j = 0; j < 4; ++j) {
|
1120
1208
|
const int aux_q4 = get_int_b4(bq4->qs, iqs + j);
|
1121
|
-
const int2 v = get_int_from_table_16(aux_q4);
|
1209
|
+
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
|
1122
1210
|
|
1123
1211
|
const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);
|
1124
1212
|
const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);
|
@@ -6,6 +6,10 @@
|
|
6
6
|
#include <cuda_bf16.h>
|
7
7
|
#include <cuda_fp16.h>
|
8
8
|
|
9
|
+
#if CUDART_VERSION >= 12050
|
10
|
+
#include <cuda_fp8.h>
|
11
|
+
#endif // CUDART_VERSION >= 12050
|
12
|
+
|
9
13
|
#if CUDART_VERSION < 11020
|
10
14
|
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
11
15
|
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
|
@@ -1,18 +1,11 @@
|
|
1
1
|
#pragma once
|
2
2
|
|
3
|
-
#define
|
3
|
+
#define HIP_DISABLE_WARP_SYNC_BUILTINS 1
|
4
4
|
#include <hip/hip_runtime.h>
|
5
5
|
#include <hipblas/hipblas.h>
|
6
6
|
#include <hip/hip_fp16.h>
|
7
|
-
#include <hip/
|
8
|
-
#ifdef __HIP_PLATFORM_AMD__
|
9
|
-
// for rocblas_initialize()
|
10
|
-
#include "rocblas/rocblas.h"
|
11
|
-
#endif // __HIP_PLATFORM_AMD__
|
7
|
+
#include <hip/hip_bf16.h>
|
12
8
|
|
13
|
-
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
|
14
|
-
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
|
15
|
-
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
|
16
9
|
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
17
10
|
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
|
18
11
|
#define CUBLAS_OP_N HIPBLAS_OP_N
|
@@ -29,8 +22,10 @@
|
|
29
22
|
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
|
30
23
|
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
|
31
24
|
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
|
25
|
+
#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)
|
32
26
|
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
33
|
-
#define
|
27
|
+
#define __all_sync(mask, var) __all(var)
|
28
|
+
#define __any_sync(mask, var) __any(var)
|
34
29
|
#define cublasCreate hipblasCreate
|
35
30
|
#define cublasDestroy hipblasDestroy
|
36
31
|
#define cublasGemmEx hipblasGemmEx
|
@@ -42,7 +37,6 @@
|
|
42
37
|
#define cublasSgemm hipblasSgemm
|
43
38
|
#define cublasStatus_t hipblasStatus_t
|
44
39
|
#define cublasOperation_t hipblasOperation_t
|
45
|
-
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
|
46
40
|
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
|
47
41
|
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
48
42
|
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
@@ -144,24 +138,61 @@
|
|
144
138
|
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
|
145
139
|
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
|
146
140
|
|
141
|
+
#if HIP_VERSION >= 60500000
|
142
|
+
#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
|
143
|
+
#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
|
144
|
+
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
|
145
|
+
#define cublasComputeType_t hipblasComputeType_t
|
146
|
+
#define cudaDataType_t hipDataType
|
147
|
+
#else
|
148
|
+
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
|
149
|
+
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
|
150
|
+
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
|
151
|
+
#define cublasComputeType_t hipblasDatatype_t
|
152
|
+
#define cudaDataType_t hipblasDatatype_t
|
153
|
+
#endif // HIP_VERSION >= 60500000
|
154
|
+
|
155
|
+
#if !defined(__HIP_PLATFORM_AMD__)
|
156
|
+
#error "The HIP backend supports only AMD targets"
|
157
|
+
#endif // !defined(__HIP_PLATFORM_AMD__)
|
158
|
+
|
147
159
|
#define __CUDA_ARCH__ 1300
|
148
160
|
|
149
|
-
#if defined(
|
161
|
+
#if defined(__gfx900__) || defined(__gfx906__)
|
162
|
+
#define GCN5
|
163
|
+
#endif // defined(__gfx900__) || defined(__gfx906__)
|
164
|
+
|
165
|
+
#if defined(__gfx803__)
|
166
|
+
#define GCN4
|
167
|
+
#endif // defined(__gfx803__)
|
168
|
+
|
169
|
+
#if defined(GCN5) || defined(GCN4)
|
150
170
|
#define GCN
|
151
|
-
#endif
|
171
|
+
#endif // defined(GCN5) || defined(GCN4)
|
152
172
|
|
153
|
-
#if defined(
|
154
|
-
#define
|
155
|
-
#endif
|
173
|
+
#if defined(__gfx942__)
|
174
|
+
#define CDNA3
|
175
|
+
#endif // defined(__gfx942__)
|
176
|
+
|
177
|
+
#if defined(__gfx90a__)
|
178
|
+
#define CDNA2
|
179
|
+
#endif // defined(__gfx90a__)
|
180
|
+
|
181
|
+
#if defined(__gfx908__)
|
182
|
+
#define CDNA1
|
183
|
+
#endif // defined(__gfx908__)
|
184
|
+
|
185
|
+
#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
|
186
|
+
#define CDNA // For the entire family
|
187
|
+
#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
|
156
188
|
|
157
189
|
#if defined(__GFX12__)
|
158
190
|
#define RDNA4
|
159
|
-
#endif
|
191
|
+
#endif // defined(__GFX12__)
|
160
192
|
|
161
|
-
#if defined(
|
162
|
-
defined(__gfx1150__) || defined(__gfx1151__)
|
193
|
+
#if defined(__GFX11__)
|
163
194
|
#define RDNA3
|
164
|
-
#endif
|
195
|
+
#endif // defined(__GFX11__)
|
165
196
|
|
166
197
|
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
|
167
198
|
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
|
@@ -170,13 +201,18 @@
|
|
170
201
|
|
171
202
|
#if defined(__gfx1010__) || defined(__gfx1012__)
|
172
203
|
#define RDNA1
|
173
|
-
#endif
|
204
|
+
#endif // defined(__gfx1010__) || defined(__gfx1012__)
|
205
|
+
|
206
|
+
#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
|
207
|
+
#define RDNA // For the entire family
|
208
|
+
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
|
174
209
|
|
175
210
|
#ifndef __has_builtin
|
176
211
|
#define __has_builtin(x) 0
|
177
212
|
#endif
|
178
213
|
|
179
|
-
typedef
|
214
|
+
typedef __hip_bfloat16 nv_bfloat16;
|
215
|
+
typedef __hip_bfloat162 nv_bfloat162;
|
180
216
|
|
181
217
|
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
|
182
218
|
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
|
@@ -227,17 +263,3 @@ static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigne
|
|
227
263
|
}
|
228
264
|
return c;
|
229
265
|
}
|
230
|
-
|
231
|
-
#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
|
232
|
-
// __shfl_xor() for half2 was added in ROCm 5.6
|
233
|
-
static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
|
234
|
-
typedef union half2_b32 {
|
235
|
-
half2 val;
|
236
|
-
int b32;
|
237
|
-
} half2_b32_t;
|
238
|
-
half2_b32_t tmp;
|
239
|
-
tmp.val = var;
|
240
|
-
tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
|
241
|
-
return tmp.val;
|
242
|
-
}
|
243
|
-
#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
|
@@ -13,7 +13,7 @@
|
|
13
13
|
#define CUBLAS_OP_N MUBLAS_OP_N
|
14
14
|
#define CUBLAS_OP_T MUBLAS_OP_T
|
15
15
|
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
|
16
|
-
#define CUBLAS_TF32_TENSOR_OP_MATH
|
16
|
+
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_TENSOR_OP_MATH
|
17
17
|
#define CUDA_R_16F MUSA_R_16F
|
18
18
|
#define CUDA_R_16BF MUSA_R_16BF
|
19
19
|
#define CUDA_R_32F MUSA_R_32F
|
@@ -29,7 +29,7 @@
|
|
29
29
|
#define cublasSgemm mublasSgemm
|
30
30
|
#define cublasStatus_t mublasStatus_t
|
31
31
|
#define cublasOperation_t mublasOperation_t
|
32
|
-
#define cublasGetStatusString
|
32
|
+
#define cublasGetStatusString mublasGetStatusString
|
33
33
|
#define cudaDataType_t musaDataType_t
|
34
34
|
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
|
35
35
|
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
|
@@ -137,4 +137,5 @@
|
|
137
137
|
#define cudaStreamEndCapture musaStreamEndCapture
|
138
138
|
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
|
139
139
|
|
140
|
-
typedef
|
140
|
+
typedef __mt_bfloat16 nv_bfloat16;
|
141
|
+
typedef __mt_bfloat162 nv_bfloat162;
|
@@ -46,8 +46,8 @@ if (GGML_HIP_ROCWMMA_FATTN)
|
|
46
46
|
endif()
|
47
47
|
endif()
|
48
48
|
|
49
|
-
if (${hip_VERSION} VERSION_LESS
|
50
|
-
message(FATAL_ERROR "At least ROCM/HIP
|
49
|
+
if (${hip_VERSION} VERSION_LESS 6.1)
|
50
|
+
message(FATAL_ERROR "At least ROCM/HIP V6.1 is required")
|
51
51
|
endif()
|
52
52
|
|
53
53
|
message(STATUS "HIP and hipBLAS found")
|
@@ -113,10 +113,18 @@ if (GGML_HIP_ROCWMMA_FATTN)
|
|
113
113
|
add_compile_definitions(GGML_HIP_ROCWMMA_FATTN)
|
114
114
|
endif()
|
115
115
|
|
116
|
+
if (NOT GGML_HIP_MMQ_MFMA)
|
117
|
+
add_compile_definitions(GGML_HIP_NO_MMQ_MFMA)
|
118
|
+
endif()
|
119
|
+
|
116
120
|
if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0)
|
117
121
|
add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12)
|
118
122
|
endif()
|
119
123
|
|
124
|
+
if (GGML_HIP_EXPORT_METRICS)
|
125
|
+
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps")
|
126
|
+
endif()
|
127
|
+
|
120
128
|
if (NOT GGML_CUDA_FA)
|
121
129
|
add_compile_definitions(GGML_CUDA_NO_FA)
|
122
130
|
endif()
|
@@ -73,6 +73,35 @@ static inline int ggml_up(int n, int m) {
|
|
73
73
|
return (n + m - 1) & ~(m - 1);
|
74
74
|
}
|
75
75
|
|
76
|
+
// TODO: move to ggml.h? (won't be able to inline)
|
77
|
+
static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
|
78
|
+
if (a->type != b->type) {
|
79
|
+
return false;
|
80
|
+
}
|
81
|
+
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
82
|
+
if (a->ne[i] != b->ne[i]) {
|
83
|
+
return false;
|
84
|
+
}
|
85
|
+
if (a->nb[i] != b->nb[i]) {
|
86
|
+
return false;
|
87
|
+
}
|
88
|
+
}
|
89
|
+
return true;
|
90
|
+
}
|
91
|
+
|
92
|
+
static bool ggml_op_is_empty(enum ggml_op op) {
|
93
|
+
switch (op) {
|
94
|
+
case GGML_OP_NONE:
|
95
|
+
case GGML_OP_RESHAPE:
|
96
|
+
case GGML_OP_TRANSPOSE:
|
97
|
+
case GGML_OP_VIEW:
|
98
|
+
case GGML_OP_PERMUTE:
|
99
|
+
return true;
|
100
|
+
default:
|
101
|
+
return false;
|
102
|
+
}
|
103
|
+
}
|
104
|
+
|
76
105
|
//
|
77
106
|
// logging
|
78
107
|
//
|
@@ -313,6 +342,10 @@ struct ggml_cgraph {
|
|
313
342
|
// if you need the gradients, get them from the original graph
|
314
343
|
struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph, int i0, int i1);
|
315
344
|
|
345
|
+
// ggml-alloc.c: true if the operation can reuse memory from its sources
|
346
|
+
GGML_API bool ggml_op_can_inplace(enum ggml_op op);
|
347
|
+
|
348
|
+
|
316
349
|
// Memory allocation
|
317
350
|
|
318
351
|
GGML_API void * ggml_aligned_malloc(size_t size);
|
@@ -394,6 +427,67 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
|
|
394
427
|
#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
|
395
428
|
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
|
396
429
|
|
430
|
+
static inline float ggml_e8m0_to_fp32(uint8_t x) {
|
431
|
+
uint32_t bits; // Stores the raw bit representation of the float
|
432
|
+
|
433
|
+
// Handle special case for minimum exponent (denormalized float)
|
434
|
+
if (x == 0) {
|
435
|
+
// Bit pattern for 2^(-127):
|
436
|
+
// - Sign bit: 0 (positive)
|
437
|
+
// - Exponent: 0 (denormalized number)
|
438
|
+
// - Mantissa: 0x400000 (0.5 in fractional form)
|
439
|
+
// Value = 0.5 * 2^(-126) = 2^(-127)
|
440
|
+
bits = 0x00400000;
|
441
|
+
}
|
442
|
+
// note: disabled as we don't need to handle NaNs
|
443
|
+
//// Handle special case for NaN (all bits set)
|
444
|
+
//else if (x == 0xFF) {
|
445
|
+
// // Standard quiet NaN pattern:
|
446
|
+
// // - Sign bit: 0
|
447
|
+
// // - Exponent: all 1s (0xFF)
|
448
|
+
// // - Mantissa: 0x400000 (quiet NaN flag)
|
449
|
+
// bits = 0x7FC00000;
|
450
|
+
//}
|
451
|
+
// Normalized values (most common case)
|
452
|
+
else {
|
453
|
+
// Construct normalized float by shifting exponent into position:
|
454
|
+
// - Exponent field: 8 bits (positions 30-23)
|
455
|
+
// - Mantissa: 0 (implicit leading 1)
|
456
|
+
// Value = 2^(x - 127)
|
457
|
+
bits = (uint32_t) x << 23;
|
458
|
+
}
|
459
|
+
|
460
|
+
float result; // Final float value
|
461
|
+
// Safely reinterpret bit pattern as float without type-punning issues
|
462
|
+
memcpy(&result, &bits, sizeof(float));
|
463
|
+
return result;
|
464
|
+
}
|
465
|
+
|
466
|
+
// Equal to ggml_e8m0_to_fp32/2
|
467
|
+
// Useful with MXFP4 quantization since the E0M2 values are doubled
|
468
|
+
static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
|
469
|
+
uint32_t bits;
|
470
|
+
|
471
|
+
// For x < 2: use precomputed denormal patterns
|
472
|
+
if (x < 2) {
|
473
|
+
// 0x00200000 = 2^(-128), 0x00400000 = 2^(-127)
|
474
|
+
bits = 0x00200000 << x;
|
475
|
+
}
|
476
|
+
// For x >= 2: normalized exponent adjustment
|
477
|
+
else {
|
478
|
+
// 0.5 * 2^(x-127) = 2^(x-128) = normalized with exponent (x-1)
|
479
|
+
bits = (uint32_t)(x - 1) << 23;
|
480
|
+
}
|
481
|
+
// Note: NaNs are not handled here
|
482
|
+
|
483
|
+
float result;
|
484
|
+
memcpy(&result, &bits, sizeof(float));
|
485
|
+
return result;
|
486
|
+
}
|
487
|
+
|
488
|
+
#define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
|
489
|
+
#define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
|
490
|
+
|
397
491
|
/**
|
398
492
|
* Converts brain16 to float32.
|
399
493
|
*
|
@@ -493,27 +587,27 @@ static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int n
|
|
493
587
|
return true;
|
494
588
|
}
|
495
589
|
|
496
|
-
// Returns true if nodes
|
590
|
+
// Returns true if nodes with indices { node_idxs } are the sequence of ggml_ops in ops[]
|
497
591
|
// and are fusable. Nodes are considered fusable according to this function if:
|
498
592
|
// - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_N_uses).
|
499
593
|
// - all nodes except the last are a src of the following node.
|
500
594
|
// - all nodes are the same shape.
|
501
595
|
// TODO: Consider allowing GGML_OP_NONE nodes in between
|
502
|
-
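static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) {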
|
503
|
-
if (node_idx + num_ops > cgraph->n_nodes) {
|
504
|
-
return false;
|
505
|
-
}
|
506
|
-
|
596
|
+
static inline bool ggml_can_fuse_ext(const struct ggml_cgraph * cgraph, const int * node_idxs, const enum ggml_op * ops, int num_ops) {
|
507
597
|
for (int i = 0; i < num_ops; ++i) {
|
508
|
-
|
598
|
+
if (node_idxs[i] >= cgraph->n_nodes) {
|
599
|
+
return false;
|
600
|
+
}
|
601
|
+
|
602
|
+
struct ggml_tensor * node = cgraph->nodes[node_idxs[i]];
|
509
603
|
if (node->op != ops[i]) {
|
510
604
|
return false;
|
511
605
|
}
|
512
|
-
if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idx + i, 1)) {
|
606
|
+
if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) {
|
513
607
|
return false;
|
514
608
|
}
|
515
609
|
if (i > 0) {
|
516
|
-
struct ggml_tensor * prev = cgraph->nodes[node_idx + i - 1];
|
610
|
+
struct ggml_tensor * prev = cgraph->nodes[node_idxs[i - 1]];
|
517
611
|
if (node->src[0] != prev && node->src[1] != prev) {
|
518
612
|
return false;
|
519
613
|
}
|
@@ -525,6 +619,22 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
|
|
525
619
|
return true;
|
526
620
|
}
|
527
621
|
|
622
|
+
// same as above, for sequential indices starting at node_idx
|
623
|
+
static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) {
|
624
|
+
assert(num_ops < 32);
|
625
|
+
|
626
|
+
if (node_idx + num_ops > cgraph->n_nodes) {
|
627
|
+
return false;
|
628
|
+
}
|
629
|
+
|
630
|
+
int idxs[32];
|
631
|
+
for (int i = 0; i < num_ops; ++i) {
|
632
|
+
idxs[i] = node_idx + i;
|
633
|
+
}
|
634
|
+
|
635
|
+
return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
|
636
|
+
}
|
637
|
+
|
528
638
|
#ifdef __cplusplus
|
529
639
|
}
|
530
640
|
#endif
|
@@ -5,7 +5,12 @@ find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
|
|
5
5
|
message(STATUS "Metal framework found")
|
6
6
|
|
7
7
|
ggml_add_backend_library(ggml-metal
|
8
|
-
ggml-metal.m
|
8
|
+
ggml-metal.cpp
|
9
|
+
ggml-metal-device.m
|
10
|
+
ggml-metal-device.cpp
|
11
|
+
ggml-metal-common.cpp
|
12
|
+
ggml-metal-context.m
|
13
|
+
ggml-metal-ops.cpp
|
9
14
|
)
|
10
15
|
|
11
16
|
target_link_libraries(ggml-metal PRIVATE
|
@@ -18,10 +23,6 @@ if (GGML_METAL_NDEBUG)
|
|
18
23
|
add_compile_definitions(GGML_METAL_NDEBUG)
|
19
24
|
endif()
|
20
25
|
|
21
|
-
if (GGML_METAL_USE_BF16)
|
22
|
-
add_compile_definitions(GGML_METAL_USE_BF16)
|
23
|
-
endif()
|
24
|
-
|
25
26
|
# copy metal files to bin directory
|
26
27
|
configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
|
27
28
|
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
|
@@ -71,7 +72,9 @@ else()
|
|
71
72
|
# note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1
|
72
73
|
# note: unfortunately, we have to call it default.metallib instead of ggml.metallib
|
73
74
|
# ref: https://github.com/ggerganov/whisper.cpp/issues/1720
|
74
|
-
|
75
|
+
# note: adding -g causes segmentation fault during compile
|
76
|
+
#set(XC_FLAGS -fno-fast-math -fno-inline -g)
|
77
|
+
set(XC_FLAGS -fno-fast-math -fno-inline)
|
75
78
|
else()
|
76
79
|
set(XC_FLAGS -O3)
|
77
80
|
endif()
|
@@ -90,7 +93,7 @@ else()
|
|
90
93
|
add_custom_command(
|
91
94
|
OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
|
92
95
|
COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o - |
|
93
|
-
|
96
|
+
xcrun -sdk macosx metallib - -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
|
94
97
|
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h
|
95
98
|
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal
|
96
99
|
DEPENDS ggml-metal.metal ${METALLIB_COMMON}
|