whispercpp 1.3.2 → 1.3.4
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +6 -3
- data/README.md +71 -14
- data/Rakefile +20 -7
- data/ext/.gitignore +4 -6
- data/ext/dependencies.rb +36 -24
- data/ext/extconf.rb +1 -1
- data/ext/options.rb +48 -184
- data/ext/ruby_whisper.c +18 -0
- data/ext/ruby_whisper_context.c +43 -12
- data/ext/ruby_whisper_model.c +1 -1
- data/ext/ruby_whisper_params.c +59 -27
- data/ext/ruby_whisper_segment.c +81 -4
- data/ext/ruby_whisper_transcribe.cpp +13 -7
- data/ext/ruby_whisper_vad_params.c +1 -1
- data/ext/sources/CMakeLists.txt +5 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/build-xcframework.sh +24 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
- data/ext/sources/examples/addon.node/addon.cpp +154 -35
- data/ext/sources/examples/addon.node/index.js +10 -5
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/bench.cpp +29 -18
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +7 -4
- data/ext/sources/examples/command/command.cpp +58 -32
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/common-whisper.cpp +14 -7
- data/ext/sources/examples/lsp/lsp.cpp +21 -17
- data/ext/sources/examples/quantize/quantize.cpp +3 -0
- data/ext/sources/examples/server/CMakeLists.txt +3 -0
- data/ext/sources/examples/server/server.cpp +193 -35
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +10 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -0
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
- data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +756 -15
- data/ext/sources/examples/talk-llama/llama-arch.h +85 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +773 -272
- data/ext/sources/examples/talk-llama/llama-batch.h +126 -55
- data/ext/sources/examples/talk-llama/llama-chat.cpp +150 -13
- data/ext/sources/examples/talk-llama/llama-chat.h +8 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +814 -542
- data/ext/sources/examples/talk-llama/llama-context.h +68 -32
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
- data/ext/sources/examples/talk-llama/llama-graph.cpp +787 -440
- data/ext/sources/examples/talk-llama/llama-graph.h +333 -153
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +128 -6
- data/ext/sources/examples/talk-llama/llama-hparams.h +80 -17
- data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +137 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +1248 -1967
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +218 -345
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +164 -52
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +266 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +139 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1154 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +182 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +94 -4
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +44 -17
- data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +11377 -5248
- data/ext/sources/examples/talk-llama/llama-model.h +87 -9
- data/ext/sources/examples/talk-llama/llama-quant.cpp +137 -16
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +502 -38
- data/ext/sources/examples/talk-llama/llama-vocab.h +46 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -17
- data/ext/sources/examples/talk-llama/llama.h +176 -151
- data/ext/sources/examples/talk-llama/talk-llama.cpp +11 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +212 -0
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +6 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
- data/ext/sources/ggml/CMakeLists.txt +106 -33
- data/ext/sources/ggml/cmake/common.cmake +24 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-backend.h +18 -2
- data/ext/sources/ggml/include/ggml-cpu.h +2 -0
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml.h +365 -21
- data/ext/sources/ggml/src/CMakeLists.txt +98 -25
- data/ext/sources/ggml/src/ggml-alloc.c +265 -141
- data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +35 -13
- data/ext/sources/ggml/src/ggml-backend.cpp +266 -60
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +4 -4
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +15 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
- data/ext/sources/ggml/src/ggml-cann/common.h +149 -2
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +521 -78
- data/ext/sources/ggml/src/ggml-common.h +21 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +165 -50
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -3
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +3650 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1891 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2160 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1897 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +214 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +18 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +23 -7
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +179 -110
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +44 -33
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +228 -98
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +532 -1124
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +3374 -2081
- data/ext/sources/ggml/src/ggml-cpu/ops.h +13 -8
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1193 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +34 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1982 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +120 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +367 -46
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +3 -3
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +272 -35
- data/ext/sources/ggml/src/ggml-cpu/vec.h +794 -142
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +291 -81
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +117 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +20 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +499 -368
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +142 -93
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +90 -50
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +636 -222
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +73 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +506 -0
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +4 -5
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -98
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +23 -19
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +179 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -2
- data/ext/sources/ggml/src/ggml-impl.h +229 -175
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +21 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +163 -63
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3208 -1575
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +32 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4430 -792
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +138 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +117 -24
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +85 -62
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +13 -17
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +21 -2
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +116 -211
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +700 -1041
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +20 -9
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +17 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +2 -96
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +393 -250
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +32 -8
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -11
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +125 -21
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +4 -3
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +105 -17
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4198 -1145
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +64 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +338 -71
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml.c +802 -142
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +32 -4
- data/ext/sources/include/whisper.h +2 -0
- data/ext/sources/src/CMakeLists.txt +2 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
- data/ext/sources/src/whisper.cpp +241 -215
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +3 -3
- data/ext/sources/tests/test-vad.cpp +2 -2
- data/extsources.rb +15 -9
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +57 -2
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +75 -38
- data/{tests → test}/helper.rb +1 -12
- data/{tests → test}/test_model.rb +9 -0
- data/test/test_package.rb +51 -0
- data/{tests → test}/test_params.rb +8 -0
- data/test/test_segment.rb +146 -0
- data/{tests → test}/test_whisper.rb +70 -0
- data/whispercpp.gemspec +2 -3
- metadata +246 -191
- data/ext/sources/.dockerignore +0 -3
- data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
- data/ext/sources/ci/run.sh +0 -336
- data/ext/sources/close-issue.yml +0 -28
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -6431
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -336
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -5998
- data/tests/test_package.rb +0 -46
- data/tests/test_segment.rb +0 -74
- /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
- /data/{tests → test}/test_callback.rb +0 -0
- /data/{tests → test}/test_error.rb +0 -0
- /data/{tests → test}/test_vad.rb +0 -0
- /data/{tests → test}/test_vad_params.rb +0 -0
@@ -0,0 +1,161 @@
|
|
1
|
+
#include "conv2d-dw.cuh"
|
2
|
+
|
3
|
+
// Geometry and hyper-parameters of one depthwise 2D convolution, bundled so
// the device helpers can take a single argument.
// Member order is significant: the struct is filled by positional aggregate
// initialization.
struct conv_params {
    int in_w;        // input width
    int in_h;        // input height
    int out_w;       // output width
    int out_h;       // output height
    int kernel_w;    // filter width
    int kernel_h;    // filter height
    int stride_x;    // horizontal stride
    int stride_y;    // vertical stride
    int padding_x;   // horizontal padding
    int padding_y;   // vertical padding
    int dilation_x;  // horizontal dilation
    int dilation_y;  // vertical dilation
    int channels;    // number of channels
    int batches;     // number of batches
};
|
12
|
+
|
13
|
+
// Range of kernel taps to visit along each axis. The *_max members are
// exclusive upper limits (the convolution loops iterate while k < max).
struct kernel_bounds {
    int y_min;  // first kernel row to visit
    int y_max;  // one past the last kernel row to visit
    int x_min;  // first kernel column to visit
    int x_max;  // one past the last kernel column to visit
};
|
17
|
+
|
18
|
+
// Computes, for one output position, the range of kernel taps [min, max) per
// axis whose input coordinate (out * stride + k * dilation - padding) falls
// inside [0, in_size). Iterating only over that range lets the caller skip
// per-tap bounds checks.
__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int out_x, int out_y, const conv_params & params) {
    // padding - out * stride: how far the window origin lies before input coordinate 0.
    const int lead_y = params.padding_y - out_y * params.stride_y;
    const int lead_x = params.padding_x - out_x * params.stride_x;

    // Adding (dilation - 1) before the division rounds the quotient up for
    // non-negative numerators (ceiling division).
    const int round_y = params.dilation_y - 1;
    const int round_x = params.dilation_x - 1;

    kernel_bounds taps;
    taps.y_min = max(0, (lead_y + round_y) / params.dilation_y);
    taps.y_max = min(params.kernel_h, (params.in_h + lead_y + round_y) / params.dilation_y);
    taps.x_min = max(0, (lead_x + round_x) / params.dilation_x);
    taps.x_max = min(params.kernel_w, (params.in_w + lead_x + round_x) / params.dilation_x);
    return taps;
}
|
30
|
+
|
31
|
+
// Maps an output coordinate and a kernel tap index to the input coordinate
// that tap reads: out_coord * stride + kern_coord * dilation - padding.
// The result is not range-checked here.
__device__ __forceinline__ int calculate_input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) {
    const int window_origin = out_coord * stride - padding;  // input coordinate of tap 0
    const int tap_offset    = kern_coord * dilation;         // distance of this tap from tap 0
    return window_origin + tap_offset;
}
|
34
|
+
|
35
|
+
struct whcn_layout {
|
36
|
+
__device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
|
37
|
+
return n * (params.channels * params.in_w * params.in_h) + c * params.in_w * params.in_h + y * params.in_w + x;
|
38
|
+
}
|
39
|
+
|
40
|
+
__device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
|
41
|
+
return c * params.kernel_h * params.kernel_w + ky * params.kernel_w + kx;
|
42
|
+
}
|
43
|
+
|
44
|
+
__device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
|
45
|
+
return n * (params.channels * params.out_w * params.out_h) + c * params.out_w * params.out_h +
|
46
|
+
y * params.out_w + x;
|
47
|
+
}
|
48
|
+
|
49
|
+
__device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
|
50
|
+
int & out_x) {
|
51
|
+
out_x = global_idx % params.out_w;
|
52
|
+
out_y = (global_idx / params.out_w) % params.out_h;
|
53
|
+
c = (global_idx / (params.out_w * params.out_h)) % params.channels;
|
54
|
+
n = global_idx / (params.out_w * params.out_h * params.channels);
|
55
|
+
}
|
56
|
+
};
|
57
|
+
|
58
|
+
struct cwhn_layout {
|
59
|
+
__device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
|
60
|
+
return n * (params.channels * params.in_w * params.in_h) + (y * params.in_w + x) * params.channels + c;
|
61
|
+
}
|
62
|
+
|
63
|
+
__device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
|
64
|
+
return (ky * params.kernel_w + kx) * params.channels + c;
|
65
|
+
}
|
66
|
+
|
67
|
+
__device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
|
68
|
+
return n * (params.channels * params.out_w * params.out_h) + y * (params.out_w * params.channels) +
|
69
|
+
x * params.channels + c;
|
70
|
+
}
|
71
|
+
|
72
|
+
__device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
|
73
|
+
int & out_x) {
|
74
|
+
c = global_idx % params.channels;
|
75
|
+
out_x = (global_idx / params.channels) % params.out_w;
|
76
|
+
out_y = (global_idx / (params.channels * params.out_w)) % params.out_h;
|
77
|
+
n = global_idx / (params.channels * params.out_w * params.out_h);
|
78
|
+
}
|
79
|
+
};
|
80
|
+
|
81
|
+
template <typename T, typename Layout>
|
82
|
+
__global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restrict__ kernel, T * __restrict__ output,
|
83
|
+
const int in_w, const int in_h, const int out_w, const int out_h,
|
84
|
+
const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
|
85
|
+
const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
|
86
|
+
const int channels, const int batches) {
|
87
|
+
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
88
|
+
const int total_elements = batches * channels * out_h * out_w;
|
89
|
+
|
90
|
+
if (global_idx >= total_elements) {
|
91
|
+
return;
|
92
|
+
}
|
93
|
+
|
94
|
+
conv_params params = { in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x,
|
95
|
+
stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };
|
96
|
+
|
97
|
+
int batch_idx, channel_idx, out_y_idx, out_x_idx;
|
98
|
+
Layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
|
99
|
+
|
100
|
+
T accumulator = 0;
|
101
|
+
kernel_bounds bounds = calculate_kernel_bounds(out_x_idx, out_y_idx, params);
|
102
|
+
|
103
|
+
for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) {
|
104
|
+
int in_y_idx = calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y);
|
105
|
+
|
106
|
+
for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) {
|
107
|
+
int in_x_idx = calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x);
|
108
|
+
|
109
|
+
const T input_val = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
|
110
|
+
const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)];
|
111
|
+
|
112
|
+
accumulator += input_val * kernel_val;
|
113
|
+
}
|
114
|
+
}
|
115
|
+
|
116
|
+
output[Layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = accumulator;
|
117
|
+
}
|
118
|
+
|
119
|
+
// Host-side launcher for the depthwise 2D convolution op.
// dst->src[0] is the filter and dst->src[1] the input; filter, input and dst
// must all be F32. dst->op_params holds, in order: stride x/y, padding x/y,
// dilation x/y. The kernel variant is chosen from the input's memory layout;
// any other layout aborts.
void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * kernel = dst->src[0];
    const ggml_tensor * input  = dst->src[1];

    GGML_ASSERT(kernel->type == GGML_TYPE_F32 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);

    const float * filter_data = (const float *) kernel->data;
    const float * input_data  = (const float *) input->data;
    float *       output_data = (float *) dst->data;

    const int32_t * op = (const int32_t *) dst->op_params;

    const int stride_x   = op[0];
    const int stride_y   = op[1];
    const int padding_x  = op[2];
    const int padding_y  = op[3];
    const int dilation_x = op[4];
    const int dilation_y = op[5];

    const int in_w     = input->ne[0];
    const int in_h     = input->ne[1];
    const int kernel_w = kernel->ne[0];
    const int kernel_h = kernel->ne[1];
    const int out_w    = dst->ne[0];
    const int out_h    = dst->ne[1];
    const int channels = dst->ne[2];
    const int batches  = dst->ne[3];

    cudaStream_t stream = ctx.stream();

    // One thread per output element; round the block count up.
    const int n_outputs = batches * channels * out_h * out_w;
    const int n_blocks  = (n_outputs + CUDA_CONV2D_DW_BLOCK_SIZE - 1) / CUDA_CONV2D_DW_BLOCK_SIZE;

    if (ggml_is_contiguous(input)) {
        conv2d_dw_kernel<float, whcn_layout><<<n_blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, stream>>>(
            input_data, filter_data, output_data, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y,
            padding_x, padding_y, dilation_x, dilation_y, channels, batches);
    } else if (ggml_is_contiguous_channels(input)) {
        conv2d_dw_kernel<float, cwhn_layout><<<n_blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, stream>>>(
            input_data, filter_data, output_data, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y,
            padding_x, padding_y, dilation_x, dilation_y, channels, batches);
    } else {
        GGML_ABORT("Unsupported memory layout for conv_2d_dw");
    }
}
|
@@ -0,0 +1,91 @@
|
|
1
|
+
#include <algorithm>
|
2
|
+
|
3
|
+
#include "conv2d-transpose.cuh"
|
4
|
+
#include "ggml.h"
|
5
|
+
|
6
|
+
// Transposed 2D convolution without padding: each thread produces one output
// element (x, y, c_out, n) by gathering every input element that contributes
// to it. An input position contributes through kernel tap (kh, kw) when
// (out - k) is non-negative, divisible by the stride, and the quotient lies
// inside the input.
// Input is float, the filter is half, accumulation and output are float.
__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,
                                        float * __restrict__ output, const int in_w, const int in_h, const int out_w,
                                        const int out_h, const int kernel_w, const int kernel_h, const int stride,
                                        const int c_in, const int c_out, const int batches) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;

    // Threads past the last output element do nothing.
    if (tid >= out_w * out_h * c_out * batches) {
        return;
    }

    // Decompose the flat index; x varies fastest, then y, then channel, then batch.
    const int ox = tid % out_w;
    const int oy = (tid / out_w) % out_h;
    const int oc = (tid / (out_w * out_h)) % c_out;
    const int n  = tid / (out_w * out_h * c_out);

    float sum = 0;

    for (int ic = 0; ic < c_in; ic++) {
        for (int kh = 0; kh < kernel_h; ++kh) {
            const int y_off = oy - kh;
            if (y_off < 0 || y_off % stride != 0) {
                continue;  // before the input, or not aligned to the stride
            }
            const int iy = y_off / stride;
            if (iy >= in_h) {
                continue;
            }

            for (int kw = 0; kw < kernel_w; ++kw) {
                const int x_off = ox - kw;
                if (x_off < 0 || x_off % stride != 0) {
                    continue;
                }
                const int ix = x_off / stride;
                if (ix >= in_w) {
                    continue;
                }

                // Input indexed as (n, ic, iy, ix); filter as (ic, oc, kh, kw).
                const int in_off = ((n * c_in + ic) * in_h + iy) * in_w + ix;
                const int w_off  = ((ic * c_out + oc) * kernel_h + kh) * kernel_w + kw;

                const float in_val = input[in_off];
                const half  w_val  = kernel[w_off];

                sum += in_val * (float) w_val;
            }
        }
    }

    output[((n * c_out + oc) * out_h + oy) * out_w + ox] = sum;
}
|
53
|
+
|
54
|
+
//input is (W, H, C_in, N), Kernel is (W, H, C_out, C_in)
|
55
|
+
void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
56
|
+
const ggml_tensor * kernel = dst->src[0];
|
57
|
+
const ggml_tensor * input = dst->src[1];
|
58
|
+
|
59
|
+
GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
|
60
|
+
|
61
|
+
const float * input_data = (const float *) input->data;
|
62
|
+
float * output_data = (float *) dst->data;
|
63
|
+
const half * kernel_data = (const half *) kernel->data;
|
64
|
+
|
65
|
+
const int input_w = input->ne[0];
|
66
|
+
const int input_h = input->ne[1];
|
67
|
+
const int output_w = dst->ne[0];
|
68
|
+
const int output_h = dst->ne[1];
|
69
|
+
const int channels_in = input->ne[2];
|
70
|
+
const int channels_out = kernel->ne[2];
|
71
|
+
const int kernel_w = kernel->ne[0];
|
72
|
+
const int kernel_h = kernel->ne[1];
|
73
|
+
const int stride = dst->op_params[0];
|
74
|
+
const int batches = input->ne[3];
|
75
|
+
|
76
|
+
GGML_ASSERT(channels_in == kernel->ne[3]);
|
77
|
+
GGML_ASSERT(stride > 0);
|
78
|
+
|
79
|
+
cudaStream_t st = ctx.stream();
|
80
|
+
|
81
|
+
GGML_ASSERT(ggml_is_contiguous(input));
|
82
|
+
GGML_ASSERT(ggml_is_contiguous(kernel));
|
83
|
+
GGML_ASSERT(ggml_is_contiguous(dst));
|
84
|
+
|
85
|
+
const int total = (output_w * output_h * channels_out * batches);
|
86
|
+
const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;
|
87
|
+
|
88
|
+
conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
|
89
|
+
input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride,
|
90
|
+
channels_in, channels_out, batches);
|
91
|
+
}
|
@@ -0,0 +1,166 @@
|
|
1
|
+
#include "conv2d.cuh"
|
2
|
+
#include "convert.cuh"
|
3
|
+
|
4
|
+
// Geometry and hyper-parameters of one 2D convolution, passed to the kernel
// by value. All members are const, so an instance must be fully populated at
// construction; member order is significant for aggregate initialization.
struct conv_params {
    const int64_t IW;     // input width
    const int64_t IH;     // input height
    const int64_t OW;     // output width
    const int64_t OH;     // output height
    const int64_t KW;     // kernel width
    const int64_t KH;     // kernel height
    const int64_t ST_X;   // stride along x
    const int64_t ST_Y;   // stride along y
    const int64_t PD_X;   // padding along x
    const int64_t PD_Y;   // padding along y
    const int64_t DL_X;   // dilation along x
    const int64_t DL_Y;   // dilation along y
    const int64_t IC;     // input channels
    const int64_t OC;     // output channels
    const int64_t B;      // batches
    const int64_t TOTAL;  // number of output elements (upper bound on the flat thread index)
};
|
15
|
+
|
16
|
+
// Range of kernel taps to visit along each axis. The *_max members are
// exclusive upper limits (the convolution loops iterate while k < max).
struct kernel_bounds {
    int64_t y_min;  // first kernel row to visit
    int64_t y_max;  // one past the last kernel row to visit
    int64_t x_min;  // first kernel column to visit
    int64_t x_max;  // one past the last kernel column to visit
};
|
20
|
+
|
21
|
+
// Returns the larger of two 64-bit signed integers (b when they are equal).
__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) {
    if (a > b) {
        return a;
    }
    return b;
}
|
24
|
+
|
25
|
+
// Returns the smaller of two 64-bit signed integers (b when they are equal).
__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) {
    if (a < b) {
        return a;
    }
    return b;
}
|
28
|
+
|
29
|
+
// Computes, for one output position, the range of kernel taps [min, max) per
// axis whose input coordinate (out * stride + k * dilation - padding) falls
// inside [0, input size). Iterating only over that range lets the caller skip
// per-tap bounds checks.
__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) {
    // padding - out * stride: how far the window origin lies before input coordinate 0.
    const int64_t lead_y = P.PD_Y - out_y * P.ST_Y;
    const int64_t lead_x = P.PD_X - out_x * P.ST_X;

    // Adding (dilation - 1) before the division rounds the quotient up for
    // non-negative numerators (ceiling division).
    const int64_t round_y = P.DL_Y - 1;
    const int64_t round_x = P.DL_X - 1;

    kernel_bounds taps;
    taps.y_min = max64(0, (lead_y + round_y) / P.DL_Y);
    taps.y_max = min64(P.KH, (P.IH + lead_y + round_y) / P.DL_Y);
    taps.x_min = max64(0, (lead_x + round_x) / P.DL_X);
    taps.x_max = min64(P.KW, (P.IW + lead_x + round_x) / P.DL_X);
    return taps;
}
|
37
|
+
|
38
|
+
// Maps an output coordinate and a kernel tap index to the input coordinate
// that tap reads: out_coord * stride + kern_coord * dilation - padding.
// The result is not range-checked here.
//
// Returns int64_t: every operand is 64-bit, so returning a narrower `int`
// would silently truncate coordinates that do not fit in 32 bits.
__device__ __forceinline__ int64_t calculate_input_coord(int64_t out_coord,
                                                         int64_t kern_coord,
                                                         int64_t stride,
                                                         int64_t dilation,
                                                         int64_t padding) {
    return out_coord * stride + kern_coord * dilation - padding;
}
|
45
|
+
|
46
|
+
struct whcn_layout {
|
47
|
+
__device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
|
48
|
+
return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x;
|
49
|
+
}
|
50
|
+
|
51
|
+
__device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) {
|
52
|
+
return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx;
|
53
|
+
}
|
54
|
+
|
55
|
+
__device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
|
56
|
+
return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x;
|
57
|
+
}
|
58
|
+
|
59
|
+
__device__ static void unpack_indices(int64_t global_idx,
|
60
|
+
const conv_params & P,
|
61
|
+
int64_t & n,
|
62
|
+
int64_t & c,
|
63
|
+
int64_t & out_y,
|
64
|
+
int64_t & out_x) {
|
65
|
+
out_x = global_idx % P.OW;
|
66
|
+
out_y = (global_idx / P.OW) % P.OH;
|
67
|
+
c = (global_idx / (P.OW * P.OH)) % P.OC;
|
68
|
+
n = global_idx / (P.OW * P.OH * P.OC);
|
69
|
+
}
|
70
|
+
};
|
71
|
+
|
72
|
+
template <typename T, typename Layout>
|
73
|
+
static __global__ void conv2d_kernel(const float * __restrict__ input,
|
74
|
+
const T * __restrict__ kernel,
|
75
|
+
float * __restrict__ output,
|
76
|
+
const conv_params P) {
|
77
|
+
const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
78
|
+
|
79
|
+
if (global_idx >= P.TOTAL) {
|
80
|
+
return;
|
81
|
+
}
|
82
|
+
|
83
|
+
int64_t n, c_out, out_y, out_x;
|
84
|
+
Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);
|
85
|
+
|
86
|
+
float acc = 0.0f;
|
87
|
+
|
88
|
+
for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
|
89
|
+
kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);
|
90
|
+
|
91
|
+
for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) {
|
92
|
+
const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y);
|
93
|
+
|
94
|
+
for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
|
95
|
+
const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
|
96
|
+
|
97
|
+
const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
|
98
|
+
const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
|
99
|
+
acc += (input_val * ggml_cuda_cast<float>(kernel_val));
|
100
|
+
}
|
101
|
+
}
|
102
|
+
}
|
103
|
+
|
104
|
+
// [N, OC, OH, OW]
|
105
|
+
output[Layout::output_index(n, c_out, out_y, out_x, P)] = acc;
|
106
|
+
}
|
107
|
+
|
108
|
+
template <typename T>
|
109
|
+
static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
|
110
|
+
const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
|
111
|
+
conv2d_kernel<T, whcn_layout><<<blocks, CUDA_CONV2D_BLOCK_SIZE, 0, st>>>(X_D, K_D, Y_D, P);
|
112
|
+
}
|
113
|
+
|
114
|
+
// Half-precision-weights entry point: forwards to conv2d_cuda, with the weight type
// deduced as `half` from K_D.
static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
    conv2d_cuda(X_D, K_D, Y_D, P, st);
}
|
117
|
+
|
118
|
+
// Single-precision-weights entry point: forwards to conv2d_cuda, with the weight type
// deduced as `float` from K_D.
static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
    conv2d_cuda(X_D, K_D, Y_D, P, st);
}
|
121
|
+
|
122
|
+
// Host-side entry point for the 2D convolution op.
//
// Reads the filter from dst->src[0] and the activations from dst->src[1], decodes the
// convolution parameters from dst->op_params, and launches the kernel on the context's
// stream. The filter may be F16 or F32; activations and output are accessed as float.
//
//   ctx - CUDA backend context; only its stream is used here
//   dst - destination tensor; its data buffer receives the result
void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * kernel = dst->src[0];
    const ggml_tensor * input  = dst->src[1];

    GGML_ASSERT(ggml_is_contiguous(kernel));
    GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);

    // The input and output buffers are reinterpreted as float arrays below, and the kernel
    // receives only extents (no strides), so both tensors must be dense F32.
    GGML_ASSERT(input->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type   == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous(input));
    GGML_ASSERT(ggml_is_contiguous(dst));

    // same number of input channels
    GGML_ASSERT(input->ne[2] == kernel->ne[2]);

    // Keep the filter pointer untyped until the element type is known; the data is only read.
    const void *  K_D = kernel->data;
    const float * X_D = (const float *) input->data;
    float *       Y_D = (float *) dst->data;

    cudaStream_t st = ctx.stream();

    const int32_t * p    = (const int32_t *) dst->op_params;
    const int       ST_X = p[0];  // stride_x
    const int       ST_Y = p[1];  // stride_y
    const int       PD_X = p[2];  // padding_x
    const int       PD_Y = p[3];  // padding_y
    const int       DL_X = p[4];  // dilation_x
    const int       DL_Y = p[5];  // dilation_y

    // No cwhn
    GGML_ASSERT(p[6] == 0);

    const int IW = input->ne[0];   // input_w
    const int IH = input->ne[1];   // input_h
    const int OW = dst->ne[0];     // output_w
    const int OH = dst->ne[1];     // output_h
    const int KW = kernel->ne[0];  // kernel_w
    const int KH = kernel->ne[1];  // kernel_h
    const int IC = input->ne[2];   // input_channels
    const int OC = kernel->ne[3];  // output_channels
    const int B  = input->ne[3];   // n_batches

    // Widen the first factor so the whole product is evaluated in 64 bits; multiplying four
    // ints would overflow before the result is converted to int64_t.
    const int64_t total  = int64_t(B) * OC * OH * OW;
    conv_params   params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };

    if (kernel->type == GGML_TYPE_F16) {
        conv2d_cuda_f16(X_D, (const half *) K_D, Y_D, params, st);
    } else {
        conv2d_cuda_f32(X_D, (const float *) K_D, Y_D, params, st);
    }
}
|
@@ -6,24 +6,33 @@
|
|
6
6
|
#define CUDA_Q8_0_NE_ALIGN 2048
|
7
7
|
|
8
8
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
9
|
-
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
|
10
|
-
|
9
|
+
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
|
10
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
11
|
+
const int64_t s01, const int64_t s02, const int64_t s03) {
|
12
|
+
const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);
|
11
13
|
|
12
|
-
if (
|
14
|
+
if (i00 >= ne00) {
|
13
15
|
return;
|
14
16
|
}
|
15
17
|
|
16
|
-
const int64_t
|
17
|
-
const int64_t
|
18
|
-
const int64_t
|
18
|
+
const int64_t i01 = blockIdx.y;
|
19
|
+
const int64_t i02 = blockIdx.z % ne02;
|
20
|
+
const int64_t i03 = blockIdx.z / ne02;
|
21
|
+
|
22
|
+
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
|
23
|
+
|
24
|
+
const int64_t ib = ibx0 + i00/qk; // block index
|
25
|
+
const int64_t iqs = (i00%qk)/qr; // quant index
|
26
|
+
const int64_t iybs = i00 - i00%qk; // y block start index
|
19
27
|
const int64_t y_offset = qr == 1 ? 1 : qk/2;
|
20
28
|
|
21
29
|
// dequantize
|
22
|
-
|
30
|
+
float2 v;
|
23
31
|
dequantize_kernel(vx, ib, iqs, v);
|
24
32
|
|
25
|
-
|
26
|
-
y[
|
33
|
+
const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
|
34
|
+
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
|
35
|
+
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
|
27
36
|
}
|
28
37
|
|
29
38
|
template <bool need_check>
|
@@ -62,9 +71,7 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
|
|
62
71
|
y2[iy/2 + threadIdx.x] = __hmul2(make_half2(qs.x, qs.y), __half2half2(d));
|
63
72
|
}
|
64
73
|
#else
|
65
|
-
|
66
|
-
GGML_UNUSED(y);
|
67
|
-
GGML_UNUSED(k);
|
74
|
+
GGML_UNUSED_VARS(vx, y, k);
|
68
75
|
NO_DEVICE_CODE;
|
69
76
|
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
|
70
77
|
}
|
@@ -456,10 +463,36 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
|
|
456
463
|
}
|
457
464
|
}
|
458
465
|
|
466
|
+
template<typename dst_t>
|
467
|
+
static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
468
|
+
|
469
|
+
const int64_t i = blockIdx.x;
|
470
|
+
const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK_MXFP4);
|
471
|
+
|
472
|
+
const int64_t tid = threadIdx.x;
|
473
|
+
const int64_t il = tid/8; // 0...3
|
474
|
+
const int64_t ib = tid%8; // 0...7
|
475
|
+
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
|
476
|
+
const uint8_t * q4 = x[ib].qs + 4*il;
|
477
|
+
const float d = ggml_cuda_e8m0_to_fp32(x[ib].e);
|
478
|
+
for (int j = 0; j < 4; ++j) {
|
479
|
+
y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f;
|
480
|
+
y[j+16] = d * kvalues_mxfp4[q4[j] >> 4]*0.5f;
|
481
|
+
}
|
482
|
+
}
|
483
|
+
|
484
|
+
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
485
|
+
static void dequantize_block_cuda(const void * vx, dst_t * y,
|
486
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
487
|
+
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
|
488
|
+
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03);
|
489
|
+
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
|
490
|
+
(vx, y, ne00, ne01, ne02, s01, s02, s03);
|
491
|
+
}
|
492
|
+
|
459
493
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
460
|
-
static void
|
461
|
-
|
462
|
-
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
494
|
+
static void dequantize_block_cont_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
|
495
|
+
dequantize_block_cuda<qk, qr, dequantize_kernel, dst_t>(vx, y, k, 1, 1, 1, k/qk, k/qk, k/qk, stream);
|
463
496
|
}
|
464
497
|
|
465
498
|
static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
|
@@ -571,6 +604,12 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t
|
|
571
604
|
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
|
572
605
|
}
|
573
606
|
|
607
|
+
template<typename dst_t>
|
608
|
+
static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
|
609
|
+
const int nb = (k + QK_K - 1) / QK_K;
|
610
|
+
dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
|
611
|
+
}
|
612
|
+
|
574
613
|
template <typename src_t, typename dst_t>
|
575
614
|
static __global__ void convert_unary(
|
576
615
|
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
@@ -589,7 +628,7 @@ static __global__ void convert_unary(
|
|
589
628
|
|
590
629
|
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
|
591
630
|
const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
|
592
|
-
y[iy] =
|
631
|
+
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
|
593
632
|
}
|
594
633
|
|
595
634
|
template <typename src_t, typename dst_t>
|
@@ -624,14 +663,14 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
|
624
663
|
case GGML_TYPE_Q4_1:
|
625
664
|
return dequantize_row_q4_1_cuda;
|
626
665
|
case GGML_TYPE_Q5_0:
|
627
|
-
return
|
666
|
+
return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
|
628
667
|
case GGML_TYPE_Q5_1:
|
629
|
-
return
|
668
|
+
return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
|
630
669
|
case GGML_TYPE_Q8_0:
|
631
670
|
if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {
|
632
671
|
return dequantize_block_q8_0_f16_cuda;
|
633
672
|
}
|
634
|
-
return
|
673
|
+
return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
|
635
674
|
case GGML_TYPE_Q2_K:
|
636
675
|
return dequantize_row_q2_K_cuda;
|
637
676
|
case GGML_TYPE_Q3_K:
|
@@ -660,6 +699,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
|
660
699
|
return dequantize_row_iq4_xs_cuda;
|
661
700
|
case GGML_TYPE_IQ3_S:
|
662
701
|
return dequantize_row_iq3_s_cuda;
|
702
|
+
case GGML_TYPE_MXFP4:
|
703
|
+
return dequantize_row_mxfp4_cuda;
|
663
704
|
case GGML_TYPE_F32:
|
664
705
|
return convert_unary_cont_cuda<float>;
|
665
706
|
case GGML_TYPE_BF16:
|
@@ -676,11 +717,11 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
|
676
717
|
case GGML_TYPE_Q4_1:
|
677
718
|
return dequantize_row_q4_1_cuda;
|
678
719
|
case GGML_TYPE_Q5_0:
|
679
|
-
return
|
720
|
+
return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
|
680
721
|
case GGML_TYPE_Q5_1:
|
681
|
-
return
|
722
|
+
return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
|
682
723
|
case GGML_TYPE_Q8_0:
|
683
|
-
return
|
724
|
+
return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
|
684
725
|
case GGML_TYPE_Q2_K:
|
685
726
|
return dequantize_row_q2_K_cuda;
|
686
727
|
case GGML_TYPE_Q3_K:
|
@@ -709,6 +750,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
|
709
750
|
return dequantize_row_iq4_xs_cuda;
|
710
751
|
case GGML_TYPE_IQ3_S:
|
711
752
|
return dequantize_row_iq3_s_cuda;
|
753
|
+
case GGML_TYPE_MXFP4:
|
754
|
+
return dequantize_row_mxfp4_cuda;
|
712
755
|
case GGML_TYPE_F16:
|
713
756
|
return convert_unary_cont_cuda<half>;
|
714
757
|
case GGML_TYPE_BF16:
|
@@ -722,9 +765,61 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
|
|
722
765
|
switch (type) {
|
723
766
|
case GGML_TYPE_F32:
|
724
767
|
return convert_unary_cuda<float>;
|
768
|
+
case GGML_TYPE_Q4_0:
|
769
|
+
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
|
770
|
+
case GGML_TYPE_Q4_1:
|
771
|
+
return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
|
772
|
+
case GGML_TYPE_Q5_0:
|
773
|
+
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
|
774
|
+
case GGML_TYPE_Q5_1:
|
775
|
+
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
|
776
|
+
case GGML_TYPE_Q8_0:
|
777
|
+
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
|
725
778
|
case GGML_TYPE_BF16:
|
726
779
|
return convert_unary_cuda<nv_bfloat16>;
|
727
780
|
default:
|
728
781
|
return nullptr;
|
729
782
|
}
|
730
783
|
}
|
784
|
+
|
785
|
+
// Returns the converter that turns a source tensor of the given type into bf16 using the
// strided (non-contiguous-capable) launchers, or nullptr when the type has no such converter.
to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
    to_bf16_nc_cuda_t converter = nullptr;

    switch (type) {
        // plain floating-point sources
        case GGML_TYPE_F32:
            converter = convert_unary_cuda<float, nv_bfloat16>;
            break;
        case GGML_TYPE_F16:
            converter = convert_unary_cuda<half, nv_bfloat16>;
            break;
        // block-quantized sources
        case GGML_TYPE_Q4_0:
            converter = dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
            break;
        case GGML_TYPE_Q4_1:
            converter = dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
            break;
        case GGML_TYPE_Q5_0:
            converter = dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
            break;
        case GGML_TYPE_Q5_1:
            converter = dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
            break;
        case GGML_TYPE_Q8_0:
            converter = dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
            break;
        default:
            break;
    }

    return converter;
}
|
805
|
+
|
806
|
+
// Returns the converter that turns a source tensor of the given type into fp32 using the
// strided (non-contiguous-capable) launchers, or nullptr when the type has no such converter.
to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
    to_fp32_nc_cuda_t converter = nullptr;

    switch (type) {
        // plain floating-point sources
        case GGML_TYPE_F16:
            converter = convert_unary_cuda<half, float>;
            break;
        case GGML_TYPE_BF16:
            converter = convert_unary_cuda<nv_bfloat16, float>;
            break;
        // block-quantized sources
        case GGML_TYPE_Q4_0:
            converter = dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
            break;
        case GGML_TYPE_Q4_1:
            converter = dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
            break;
        case GGML_TYPE_Q5_0:
            converter = dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
            break;
        case GGML_TYPE_Q5_1:
            converter = dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
            break;
        case GGML_TYPE_Q8_0:
            converter = dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
            break;
        default:
            break;
    }

    return converter;
}
|