whispercpp 1.3.3 → 1.3.4
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/ruby_whisper_params.c +55 -25
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/build-xcframework.sh +24 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/addon.cpp +19 -19
- data/ext/sources/examples/addon.node/index.js +7 -5
- data/ext/sources/examples/bench/bench.cpp +26 -16
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +4 -2
- data/ext/sources/examples/command/command.cpp +26 -24
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/lsp.cpp +19 -17
- data/ext/sources/examples/server/server.cpp +24 -13
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +4 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
- data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
- data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
- data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
- data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
- data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
- data/ext/sources/examples/talk-llama/llama-context.h +44 -29
- data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
- data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
- data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
- data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
- data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
- data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
- data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
- data/ext/sources/examples/talk-llama/llama-model.h +60 -9
- data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
- data/ext/sources/examples/talk-llama/llama.cpp +65 -10
- data/ext/sources/examples/talk-llama/llama.h +95 -177
- data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
- data/ext/sources/ggml/CMakeLists.txt +59 -31
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-backend.h +17 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -1
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml.h +221 -16
- data/ext/sources/ggml/src/CMakeLists.txt +17 -2
- data/ext/sources/ggml/src/ggml-alloc.c +265 -141
- data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
- data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
- data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
- data/ext/sources/ggml/src/ggml-common.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
- data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
- data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
- data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
- data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
- data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
- data/ext/sources/ggml/src/ggml-impl.h +119 -9
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +111 -16
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml.c +478 -98
- data/ext/sources/ggml/src/gguf.cpp +8 -1
- data/ext/sources/src/whisper.cpp +23 -46
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +3 -3
- data/ext/sources/tests/test-vad.cpp +2 -2
- data/lib/whisper/model/uri.rb +1 -1
- data/sig/whisper.rbs +7 -0
- data/test/test_params.rb +8 -0
- data/test/test_whisper.rb +1 -1
- data/whispercpp.gemspec +1 -1
- metadata +164 -157
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
@@ -0,0 +1,556 @@
|
|
1
|
+
void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) {
|
2
|
+
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
3
|
+
#if LOAD_VEC_A == 8
|
4
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
5
|
+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
6
|
+
FLOAT_TYPE_VEC8 aa = FLOAT_TYPE_VEC8(data_a[idx]);
|
7
|
+
buf_a[buf_idx ] = aa[0].xy;
|
8
|
+
buf_a[buf_idx + 1] = aa[0].zw;
|
9
|
+
buf_a[buf_idx + 2] = aa[1].xy;
|
10
|
+
buf_a[buf_idx + 3] = aa[1].zw;
|
11
|
+
#elif LOAD_VEC_A == 4
|
12
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
13
|
+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
14
|
+
FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]);
|
15
|
+
buf_a[buf_idx ] = aa.xy;
|
16
|
+
buf_a[buf_idx + 1] = aa.zw;
|
17
|
+
#else // LOAD_VEC_BATCH_A == 2
|
18
|
+
const uint idx = pos_a + col * p.stride_a + row * 2;
|
19
|
+
const uint buf_idx = col * SHMEM_STRIDE + row;
|
20
|
+
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
|
21
|
+
buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx],
|
22
|
+
data_a[idx + 1]);
|
23
|
+
} else if (idx_m < p.M && block + row * 2 < end_k) {
|
24
|
+
buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], 0.0f);
|
25
|
+
} else {
|
26
|
+
buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
|
27
|
+
}
|
28
|
+
#endif
|
29
|
+
#elif defined(DATA_A_BF16)
|
30
|
+
#if LOAD_VEC_A == 4
|
31
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
32
|
+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
33
|
+
FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx]));
|
34
|
+
buf_a[buf_idx ] = aa.xy;
|
35
|
+
buf_a[buf_idx + 1] = aa.zw;
|
36
|
+
#else // LOAD_VEC_BATCH_A == 2
|
37
|
+
const uint idx = pos_a + col * p.stride_a + row * 2;
|
38
|
+
const uint buf_idx = col * SHMEM_STRIDE + row;
|
39
|
+
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
|
40
|
+
buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]),
|
41
|
+
TO_FLOAT_TYPE(data_a[idx + 1]));
|
42
|
+
} else if (idx_m < p.M && block + row * 2 < end_k) {
|
43
|
+
buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), 0.0f);
|
44
|
+
} else {
|
45
|
+
buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
|
46
|
+
}
|
47
|
+
#endif
|
48
|
+
#elif defined(DATA_A_Q4_0)
|
49
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
50
|
+
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
|
51
|
+
|
52
|
+
const uint ib = idx / 4;
|
53
|
+
const uint iqs = idx & 0x03;
|
54
|
+
|
55
|
+
const float d = float(data_a_packed16[ib].d);
|
56
|
+
const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
|
57
|
+
const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
|
58
|
+
const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;
|
59
|
+
|
60
|
+
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
|
61
|
+
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw);
|
62
|
+
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy);
|
63
|
+
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);
|
64
|
+
#elif defined(DATA_A_Q4_1)
|
65
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
66
|
+
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
|
67
|
+
|
68
|
+
const uint ib = idx / 4;
|
69
|
+
const uint iqs = idx & 0x03;
|
70
|
+
|
71
|
+
const float d = float(data_a_packed16[ib].d);
|
72
|
+
const float m = float(data_a_packed16[ib].m);
|
73
|
+
const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
|
74
|
+
const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
|
75
|
+
const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
|
76
|
+
|
77
|
+
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
|
78
|
+
buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw);
|
79
|
+
buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v1.xy);
|
80
|
+
buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw);
|
81
|
+
#elif defined(DATA_A_Q5_0)
|
82
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
83
|
+
const uint buf_idx = col * SHMEM_STRIDE + row;
|
84
|
+
|
85
|
+
const uint ib = idx / 8;
|
86
|
+
const uint iqs = idx & 0x07;
|
87
|
+
|
88
|
+
const float d = float(data_a_packed16[ib].d);
|
89
|
+
const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]);
|
90
|
+
const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
|
91
|
+
const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
|
92
|
+
|
93
|
+
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
|
94
|
+
const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;
|
95
|
+
|
96
|
+
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
|
97
|
+
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
|
98
|
+
#elif defined(DATA_A_Q5_1)
|
99
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
100
|
+
const uint buf_idx = col * SHMEM_STRIDE + row;
|
101
|
+
|
102
|
+
const uint ib = idx / 8;
|
103
|
+
const uint iqs = idx & 0x07;
|
104
|
+
|
105
|
+
const float d = float(data_a_packed16[ib].d);
|
106
|
+
const float m = float(data_a_packed16[ib].m);
|
107
|
+
const uint uint_qh = data_a_packed16[ib].qh;
|
108
|
+
const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
|
109
|
+
const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
|
110
|
+
|
111
|
+
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
|
112
|
+
const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
|
113
|
+
|
114
|
+
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
|
115
|
+
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
|
116
|
+
#elif defined(DATA_A_Q8_0)
|
117
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
118
|
+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
119
|
+
|
120
|
+
const uint ib = idx / 8;
|
121
|
+
const uint iqs = idx & 0x07;
|
122
|
+
|
123
|
+
const float d = float(data_a_packed16[ib].d);
|
124
|
+
const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147
|
125
|
+
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
|
126
|
+
const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
|
127
|
+
|
128
|
+
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
|
129
|
+
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
|
130
|
+
#elif defined(DATA_A_Q2_K)
|
131
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
132
|
+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
133
|
+
|
134
|
+
const uint ib = idx / 128; // 2 values per idx
|
135
|
+
const uint iqs = idx % 128; // 0..127
|
136
|
+
|
137
|
+
const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
|
138
|
+
const uint scalesi = iqs / 8; // 0..15
|
139
|
+
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
|
140
|
+
|
141
|
+
const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
|
142
|
+
const uint scales = data_a[ib].scales[scalesi];
|
143
|
+
const vec2 d = vec2(data_a[ib].d);
|
144
|
+
|
145
|
+
const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
|
146
|
+
|
147
|
+
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
|
148
|
+
#elif defined(DATA_A_Q3_K)
|
149
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
150
|
+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
151
|
+
|
152
|
+
const uint ib = idx / 128; // 2 values per idx
|
153
|
+
const uint iqs = idx % 128; // 0..127
|
154
|
+
|
155
|
+
const uint n = iqs / 64; // 0,1
|
156
|
+
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
|
157
|
+
const uint hmi = (iqs % 16) * 2; // 0,2,4..30
|
158
|
+
const uint j = (iqs % 64) / 4; // 0..3
|
159
|
+
const uint is = iqs / 8; // 0..15
|
160
|
+
const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3
|
161
|
+
const uint qsshift = halfsplit * 2; // 0,2,4,6
|
162
|
+
const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
|
163
|
+
|
164
|
+
const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF)
|
165
|
+
| (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
|
166
|
+
const float dl = float(data_a[ib].d) * float(us - 32);
|
167
|
+
|
168
|
+
buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)),
|
169
|
+
dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
|
170
|
+
#elif defined(DATA_A_Q4_K)
|
171
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
172
|
+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
173
|
+
|
174
|
+
const uint ib = idx / 128; // 2 values per idx
|
175
|
+
const uint iqs = idx % 128; // 0..127
|
176
|
+
|
177
|
+
const uint n = iqs / 32; // 0,1,2,3
|
178
|
+
const uint b = (iqs % 32) / 16; // 0,1
|
179
|
+
const uint is = 2 * n + b; // 0..7
|
180
|
+
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
|
181
|
+
|
182
|
+
const vec2 loadd = vec2(data_a[ib].d);
|
183
|
+
|
184
|
+
const uint scidx0 = (is < 4) ? is : (is + 4);
|
185
|
+
const uint scidx1 = (is < 4) ? is : (is - 4);
|
186
|
+
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
187
|
+
const uint scidxshift1 = (is < 4) ? 0 : 2;
|
188
|
+
const uint mbidx0 = is + 4;
|
189
|
+
const uint mbidx1 = (is < 4) ? is + 4 : is;
|
190
|
+
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
|
191
|
+
const uint mbidxshift0 = (is < 4) ? 0 : 4;
|
192
|
+
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
193
|
+
const uint mbidxshift1 = (is < 4) ? 0 : 2;
|
194
|
+
|
195
|
+
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
|
196
|
+
const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
|
197
|
+
|
198
|
+
const float d = loadd.x * sc;
|
199
|
+
const float m = -loadd.y * mbyte;
|
200
|
+
|
201
|
+
buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m),
|
202
|
+
fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
|
203
|
+
#elif defined(DATA_A_Q5_K)
|
204
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
205
|
+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
206
|
+
|
207
|
+
const uint ib = idx / 128; // 2 values per idx
|
208
|
+
const uint iqs = idx % 128; // 0..127
|
209
|
+
|
210
|
+
const uint n = iqs / 32; // 0,1,2,3
|
211
|
+
const uint b = (iqs % 32) / 16; // 0,1
|
212
|
+
const uint is = 2 * n + b; // 0..7
|
213
|
+
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
|
214
|
+
const uint qhi = (iqs % 16) * 2; // 0,2,4..30
|
215
|
+
|
216
|
+
const uint8_t hm = uint8_t(1 << (iqs / 16));
|
217
|
+
|
218
|
+
const vec2 loadd = vec2(data_a[ib].d);
|
219
|
+
|
220
|
+
const uint scidx0 = (is < 4) ? is : (is + 4);
|
221
|
+
const uint scidx1 = (is < 4) ? is : (is - 4);
|
222
|
+
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
223
|
+
const uint scidxshift1 = (is < 4) ? 0 : 2;
|
224
|
+
const uint mbidx0 = is + 4;
|
225
|
+
const uint mbidx1 = (is < 4) ? is + 4 : is;
|
226
|
+
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
|
227
|
+
const uint mbidxshift0 = (is < 4) ? 0 : 4;
|
228
|
+
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
229
|
+
const uint mbidxshift1 = (is < 4) ? 0 : 2;
|
230
|
+
|
231
|
+
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
|
232
|
+
const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
|
233
|
+
|
234
|
+
const float d = loadd.x * sc;
|
235
|
+
const float m = -loadd.y * mbyte;
|
236
|
+
|
237
|
+
buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m),
|
238
|
+
fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
|
239
|
+
#elif defined(DATA_A_Q6_K)
|
240
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
241
|
+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
242
|
+
|
243
|
+
const uint ib = idx / 128; // 2 values per idx
|
244
|
+
const uint iqs = idx % 128; // 0..127
|
245
|
+
|
246
|
+
const uint n = iqs / 64; // 0,1
|
247
|
+
const uint b = (iqs % 64) / 32; // 0,1
|
248
|
+
const uint is_b = (iqs % 16) / 8; // 0,1
|
249
|
+
const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
|
250
|
+
const uint is = 8 * n + qhshift + is_b; // 0..15
|
251
|
+
const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
|
252
|
+
const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
|
253
|
+
|
254
|
+
const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);
|
255
|
+
|
256
|
+
buf_a[buf_idx] = FLOAT_TYPE_VEC2(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32),
|
257
|
+
dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
|
258
|
+
#elif defined(DATA_A_IQ1_S)
|
259
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
260
|
+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
261
|
+
|
262
|
+
const uint ib = idx / 32; // 8 values per idx
|
263
|
+
const uint ib32 = (idx % 32) / 4; // 0..7
|
264
|
+
const uint ib8 = idx % 32;
|
265
|
+
|
266
|
+
const float d = float(data_a[ib].d);
|
267
|
+
const uint qh = data_a[ib].qh[ib32];
|
268
|
+
const uint qs = data_a[ib].qs[ib8];
|
269
|
+
const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1);
|
270
|
+
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
271
|
+
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
|
272
|
+
|
273
|
+
[[unroll]] for (int k = 0; k < 4; ++k) {
|
274
|
+
buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
|
275
|
+
dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
|
276
|
+
}
|
277
|
+
#elif defined(DATA_A_IQ1_M)
|
278
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
279
|
+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
280
|
+
|
281
|
+
const uint ib = idx / 32; // 8 values per idx
|
282
|
+
const uint ib8 = idx % 32;
|
283
|
+
const uint ib16 = ib8 / 2;
|
284
|
+
|
285
|
+
const uint16_t[4] scales = data_a[ib].scales;
|
286
|
+
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
|
287
|
+
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
|
288
|
+
const uint sc = scales[ib8 / 8];
|
289
|
+
const uint qs = data_a[ib].qs[ib8];
|
290
|
+
const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1));
|
291
|
+
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
|
292
|
+
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
|
293
|
+
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
294
|
+
|
295
|
+
[[unroll]] for (int k = 0; k < 4; ++k) {
|
296
|
+
buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
|
297
|
+
dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
|
298
|
+
}
|
299
|
+
#elif defined(DATA_A_IQ2_XXS)
|
300
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
301
|
+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
302
|
+
|
303
|
+
const uint ib = idx / 32; // 8 values per idx
|
304
|
+
const uint ib32 = (idx % 32) / 4; // 0..7
|
305
|
+
const uint ib8 = idx % 4;
|
306
|
+
|
307
|
+
const float d = float(data_a[ib].d);
|
308
|
+
const uint qs = data_a[ib].qs[8 * ib32 + ib8];
|
309
|
+
const uint signs = pack32(u8vec4(
|
310
|
+
data_a[ib].qs[8*ib32 + 4],
|
311
|
+
data_a[ib].qs[8*ib32 + 5],
|
312
|
+
data_a[ib].qs[8*ib32 + 6],
|
313
|
+
data_a[ib].qs[8*ib32 + 7]
|
314
|
+
));
|
315
|
+
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
|
316
|
+
const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
|
317
|
+
const uint sign = sign7 | (bitCount(sign7) << 7);
|
318
|
+
const uvec2 grid = iq2xxs_grid[qs];
|
319
|
+
const vec4 grid0 = vec4(unpack8(grid.x));
|
320
|
+
const vec4 grid1 = vec4(unpack8(grid.y));
|
321
|
+
|
322
|
+
buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
|
323
|
+
(sign & 2) != 0 ? -grid0.y : grid0.y);
|
324
|
+
buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
|
325
|
+
(sign & 8) != 0 ? -grid0.w : grid0.w);
|
326
|
+
buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
|
327
|
+
(sign & 32) != 0 ? -grid1.y : grid1.y);
|
328
|
+
buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
|
329
|
+
(sign & 128) != 0 ? -grid1.w : grid1.w);
|
330
|
+
#elif defined(DATA_A_IQ2_XS)
|
331
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
332
|
+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
333
|
+
|
334
|
+
const uint ib = idx / 32; // 8 values per idx
|
335
|
+
const uint ib32 = (idx % 32) / 4; // 0..7
|
336
|
+
const uint ib8 = idx % 4; // 0..3
|
337
|
+
|
338
|
+
const float d = float(data_a[ib].d);
|
339
|
+
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
|
340
|
+
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
|
341
|
+
const uint qs = data_a[ib].qs[4 * ib32 + ib8];
|
342
|
+
const uint sign7 = qs >> 9;
|
343
|
+
const uint sign = sign7 | (bitCount(sign7) << 7);
|
344
|
+
const uvec2 grid = iq2xs_grid[qs & 511];
|
345
|
+
const vec4 grid0 = vec4(unpack8(grid.x));
|
346
|
+
const vec4 grid1 = vec4(unpack8(grid.y));
|
347
|
+
|
348
|
+
buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
|
349
|
+
(sign & 2) != 0 ? -grid0.y : grid0.y);
|
350
|
+
buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
|
351
|
+
(sign & 8) != 0 ? -grid0.w : grid0.w);
|
352
|
+
buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
|
353
|
+
(sign & 32) != 0 ? -grid1.y : grid1.y);
|
354
|
+
buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
|
355
|
+
(sign & 128) != 0 ? -grid1.w : grid1.w);
|
356
|
+
#elif defined(DATA_A_IQ2_S)
|
357
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
358
|
+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
359
|
+
|
360
|
+
const uint ib = idx / 32; // 8 values per idx
|
361
|
+
const uint ib8 = idx % 32; // 0..31
|
362
|
+
const uint ib32 = ib8 / 4; // 0..7
|
363
|
+
|
364
|
+
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
|
365
|
+
const uint qs = data_a[ib].qs[ib8];
|
366
|
+
const uint qh = data_a[ib].qh[ib32];
|
367
|
+
const uint qhshift = 2 * (ib8 % 4);
|
368
|
+
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
|
369
|
+
|
370
|
+
const float d = float(data_a[ib].d);
|
371
|
+
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
|
372
|
+
const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
|
373
|
+
const vec4 grid0 = vec4(unpack8(grid.x));
|
374
|
+
const vec4 grid1 = vec4(unpack8(grid.y));
|
375
|
+
|
376
|
+
buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
|
377
|
+
(sign & 2) != 0 ? -grid0.y : grid0.y);
|
378
|
+
buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
|
379
|
+
(sign & 8) != 0 ? -grid0.w : grid0.w);
|
380
|
+
buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
|
381
|
+
(sign & 32) != 0 ? -grid1.y : grid1.y);
|
382
|
+
buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
|
383
|
+
(sign & 128) != 0 ? -grid1.w : grid1.w);
|
384
|
+
#elif defined(DATA_A_IQ3_XXS)
|
385
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
386
|
+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
387
|
+
|
388
|
+
const uint ib = idx / 64; // 4 values per idx
|
389
|
+
const uint iqs = idx % 64; // 0..63
|
390
|
+
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
|
391
|
+
|
392
|
+
const float d = float(data_a[ib].d);
|
393
|
+
const uint qs = data_a[ib].qs[iqs];
|
394
|
+
const uint signs = pack32(u8vec4(
|
395
|
+
data_a[ib].qs[is+0],
|
396
|
+
data_a[ib].qs[is+1],
|
397
|
+
data_a[ib].qs[is+2],
|
398
|
+
data_a[ib].qs[is+3]
|
399
|
+
));
|
400
|
+
const float db = d * 0.5 * (0.5 + (signs >> 28));
|
401
|
+
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
|
402
|
+
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
|
403
|
+
const uint grid = iq3xxs_grid[qs];
|
404
|
+
const vec4 v = db * vec4(unpack8(grid));
|
405
|
+
|
406
|
+
buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x,
|
407
|
+
(sign & 2) != 0 ? -v.y : v.y);
|
408
|
+
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z,
|
409
|
+
(sign & 8) != 0 ? -v.w : v.w);
|
410
|
+
#elif defined(DATA_A_IQ3_S)
|
411
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
412
|
+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
413
|
+
|
414
|
+
const uint ib = idx / 64; // 4 values per idx
|
415
|
+
const uint iqs = idx % 64; // 0..63
|
416
|
+
const uint iqh = iqs / 8;
|
417
|
+
|
418
|
+
const float d = float(data_a[ib].d);
|
419
|
+
const uint qs = data_a[ib].qs[iqs];
|
420
|
+
const uint qh = data_a[ib].qh[iqh];
|
421
|
+
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
|
422
|
+
const uint scale = data_a[ib].scales[iqs / 16];
|
423
|
+
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
|
424
|
+
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
|
425
|
+
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
|
426
|
+
const vec4 v = db * vec4(unpack8(grid));
|
427
|
+
|
428
|
+
buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x,
|
429
|
+
(sign & 2) != 0 ? -v.y : v.y);
|
430
|
+
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z,
|
431
|
+
(sign & 8) != 0 ? -v.w : v.w);
|
432
|
+
#elif defined(DATA_A_IQ4_XS)
|
433
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
434
|
+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
435
|
+
|
436
|
+
const uint ib = idx / 128; // 2 values per idx
|
437
|
+
const uint ib32 = (idx % 128) / 16; // 0..7
|
438
|
+
const uint iq = 16 * ib32 + 2 * (idx % 8);
|
439
|
+
|
440
|
+
const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
|
441
|
+
const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3;
|
442
|
+
const uint qshift = (idx & 8) >> 1;
|
443
|
+
u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]);
|
444
|
+
qs = (qs >> qshift) & uint8_t(0xF);
|
445
|
+
|
446
|
+
const float d = float(data_a[ib].d);
|
447
|
+
const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
|
448
|
+
|
449
|
+
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
|
450
|
+
#elif defined(DATA_A_IQ4_NL)
|
451
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
452
|
+
const uint buf_idx = col * SHMEM_STRIDE + row;
|
453
|
+
|
454
|
+
const uint ib = idx / 8;
|
455
|
+
const uint iqs = idx & 0x07;
|
456
|
+
|
457
|
+
const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
|
458
|
+
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
|
459
|
+
|
460
|
+
buf_a[buf_idx ] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[vui & 0xF],
|
461
|
+
kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]);
|
462
|
+
buf_a[buf_idx + 8] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)],
|
463
|
+
kvalues_iq4nl[vui >> 12]);
|
464
|
+
#elif defined(DATA_A_MXFP4)
|
465
|
+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
466
|
+
const uint buf_idx = col * SHMEM_STRIDE + row;
|
467
|
+
|
468
|
+
const uint ib = idx / 8;
|
469
|
+
const uint iqs = (idx & 0x07) * 2;
|
470
|
+
|
471
|
+
const float d = e8m0_to_fp32(data_a[ib].e);
|
472
|
+
const uint vui = uint(data_a[ib].qs[iqs]);
|
473
|
+
const uint vui2 = uint(data_a[ib].qs[iqs+1]);
|
474
|
+
|
475
|
+
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui & 0xF] * d,
|
476
|
+
kvalues_mxfp4[vui2 & 0xF] * d);
|
477
|
+
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui >> 4] * d,
|
478
|
+
kvalues_mxfp4[vui2 >> 4] * d);
|
479
|
+
#endif
|
480
|
+
}
|
481
|
+
|
482
|
+
#if !defined(MUL_MAT_ID)
|
483
|
+
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) {
|
484
|
+
#if LOAD_VEC_B == 8
|
485
|
+
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
486
|
+
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
487
|
+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
488
|
+
FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]);
|
489
|
+
buf_b[buf_idx + 0] = bb[0].xy;
|
490
|
+
buf_b[buf_idx + 1] = bb[0].zw;
|
491
|
+
buf_b[buf_idx + 2] = bb[1].xy;
|
492
|
+
buf_b[buf_idx + 3] = bb[1].zw;
|
493
|
+
#elif LOAD_VEC_B == 4
|
494
|
+
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
495
|
+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
496
|
+
#if defined(DATA_B_BF16)
|
497
|
+
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx]));
|
498
|
+
#else
|
499
|
+
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]);
|
500
|
+
#endif
|
501
|
+
buf_b[buf_idx + 0] = bb.xy;
|
502
|
+
buf_b[buf_idx + 1] = bb.zw;
|
503
|
+
#else // LOAD_VEC_BATCH_B == 2
|
504
|
+
const uint idx = pos_b + col * p.stride_b + row * 2;
|
505
|
+
const uint buf_idx = col * SHMEM_STRIDE + row;
|
506
|
+
if (idx_n < p.N && block + row * 2 + 1 < end_k) {
|
507
|
+
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
|
508
|
+
TO_FLOAT_TYPE(data_b[idx + 1]));
|
509
|
+
} else if (idx_n < p.N && block + row * 2 < end_k) {
|
510
|
+
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
511
|
+
} else {
|
512
|
+
buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
|
513
|
+
}
|
514
|
+
#endif
|
515
|
+
}
|
516
|
+
#else
|
517
|
+
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) {
|
518
|
+
#if LOAD_VEC_B == 8
|
519
|
+
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
520
|
+
const u16vec2 row_idx = row_ids[col];
|
521
|
+
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
522
|
+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
523
|
+
FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]);
|
524
|
+
buf_b[buf_idx + 0] = bb[0].xy;
|
525
|
+
buf_b[buf_idx + 1] = bb[0].zw;
|
526
|
+
buf_b[buf_idx + 2] = bb[1].xy;
|
527
|
+
buf_b[buf_idx + 3] = bb[1].zw;
|
528
|
+
#elif LOAD_VEC_B == 4
|
529
|
+
const u16vec2 row_idx = row_ids[col];
|
530
|
+
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
531
|
+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
532
|
+
#if defined(DATA_B_BF16)
|
533
|
+
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx]));
|
534
|
+
#else
|
535
|
+
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]);
|
536
|
+
#endif
|
537
|
+
buf_b[buf_idx + 0] = bb.xy;
|
538
|
+
buf_b[buf_idx + 1] = bb.zw;
|
539
|
+
#else // LOAD_VEC_BATCH_B == 2
|
540
|
+
const uint row_i = ic * BN + col;
|
541
|
+
const uint buf_idx = col * SHMEM_STRIDE + row;
|
542
|
+
if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
|
543
|
+
const u16vec2 row_idx = row_ids[col];
|
544
|
+
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
|
545
|
+
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
|
546
|
+
TO_FLOAT_TYPE(data_b[idx + 1]));
|
547
|
+
} else if (row_i < _ne1 && block + row * 2 < end_k) {
|
548
|
+
const u16vec2 row_idx = row_ids[col];
|
549
|
+
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
|
550
|
+
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
551
|
+
} else {
|
552
|
+
buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
|
553
|
+
}
|
554
|
+
#endif
|
555
|
+
}
|
556
|
+
#endif
|
@@ -28,7 +28,7 @@ layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
|
|
28
28
|
#if defined(A_TYPE_PACKED32)
|
29
29
|
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
|
30
30
|
#endif
|
31
|
-
layout (binding = 1) readonly buffer B {
|
31
|
+
layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];};
|
32
32
|
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
33
33
|
|
34
34
|
#ifdef MUL_MAT_ID
|
@@ -98,7 +98,7 @@ shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
|
|
98
98
|
#endif
|
99
99
|
|
100
100
|
#define LOAD_VEC_A (4 * QUANT_R)
|
101
|
-
#define LOAD_VEC_B
|
101
|
+
#define LOAD_VEC_B 16
|
102
102
|
|
103
103
|
#ifdef MUL_MAT_ID
|
104
104
|
shared u16vec2 row_ids[4096];
|
@@ -270,15 +270,22 @@ void main() {
|
|
270
270
|
const uint iqs = idx & 0x7;
|
271
271
|
#else
|
272
272
|
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
|
273
|
+
const uint ib_outer = ib / 4;
|
274
|
+
const uint ib_inner = ib % 4;
|
275
|
+
|
273
276
|
const uint iqs = loadr_b;
|
274
277
|
#endif
|
275
278
|
|
276
279
|
const uint buf_ib = loadc_b + l;
|
277
280
|
|
278
281
|
if (iqs == 0) {
|
279
|
-
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[
|
282
|
+
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
|
280
283
|
}
|
281
|
-
|
284
|
+
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
|
285
|
+
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
|
286
|
+
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
|
287
|
+
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
|
288
|
+
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
|
282
289
|
}
|
283
290
|
|
284
291
|
barrier();
|
@@ -349,7 +356,7 @@ void main() {
|
|
349
356
|
cache_b_qs[cc * (BK / 4) + idx_k]);
|
350
357
|
}
|
351
358
|
|
352
|
-
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]);
|
359
|
+
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
|
353
360
|
}
|
354
361
|
}
|
355
362
|
}
|
@@ -16,8 +16,8 @@ i32vec2 repack(uint ib, uint iqs) {
|
|
16
16
|
(vui >> 4) & 0x0F0F0F0F);
|
17
17
|
}
|
18
18
|
|
19
|
-
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
|
20
|
-
return ACC_TYPE(da * (float(q_sum) * dsb.x - 8
|
19
|
+
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
20
|
+
return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
|
21
21
|
}
|
22
22
|
#endif
|
23
23
|
|
@@ -29,8 +29,8 @@ i32vec2 repack(uint ib, uint iqs) {
|
|
29
29
|
(vui >> 4) & 0x0F0F0F0F);
|
30
30
|
}
|
31
31
|
|
32
|
-
ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
|
33
|
-
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
|
32
|
+
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
|
33
|
+
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
|
34
34
|
}
|
35
35
|
#endif
|
36
36
|
|
@@ -50,8 +50,8 @@ i32vec2 repack(uint ib, uint iqs) {
|
|
50
50
|
return i32vec2(v0, v1);
|
51
51
|
}
|
52
52
|
|
53
|
-
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
|
54
|
-
return ACC_TYPE(da * (float(q_sum) * dsb.x - 16
|
53
|
+
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
54
|
+
return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
|
55
55
|
}
|
56
56
|
#endif
|
57
57
|
|
@@ -69,8 +69,8 @@ i32vec2 repack(uint ib, uint iqs) {
|
|
69
69
|
return i32vec2(v0, v1);
|
70
70
|
}
|
71
71
|
|
72
|
-
ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
|
73
|
-
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
|
72
|
+
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
|
73
|
+
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
|
74
74
|
}
|
75
75
|
#endif
|
76
76
|
|
@@ -81,7 +81,7 @@ int32_t repack(uint ib, uint iqs) {
|
|
81
81
|
data_a[ib].qs[iqs * 2 + 1]));
|
82
82
|
}
|
83
83
|
|
84
|
-
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
|
84
|
+
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
85
85
|
return ACC_TYPE(float(q_sum) * da * dsb.x);
|
86
86
|
}
|
87
87
|
#endif
|
@@ -92,6 +92,12 @@ FLOAT_TYPE get_d(uint ib) {
|
|
92
92
|
}
|
93
93
|
#endif
|
94
94
|
|
95
|
+
#if defined(DATA_A_MXFP4)
|
96
|
+
FLOAT_TYPE get_d(uint ib) {
|
97
|
+
return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
|
98
|
+
}
|
99
|
+
#endif
|
100
|
+
|
95
101
|
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
|
96
102
|
FLOAT_TYPE_VEC2 get_dm(uint ib) {
|
97
103
|
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
|