whispercpp 1.3.3 → 1.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/ruby_whisper_params.c +55 -25
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/build-xcframework.sh +24 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/addon.cpp +19 -19
- data/ext/sources/examples/addon.node/index.js +7 -5
- data/ext/sources/examples/bench/bench.cpp +26 -16
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +4 -2
- data/ext/sources/examples/command/command.cpp +26 -24
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/lsp.cpp +19 -17
- data/ext/sources/examples/server/server.cpp +24 -13
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +4 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
- data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
- data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
- data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
- data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
- data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
- data/ext/sources/examples/talk-llama/llama-context.h +44 -29
- data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
- data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
- data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
- data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
- data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
- data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
- data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
- data/ext/sources/examples/talk-llama/llama-model.h +60 -9
- data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
- data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
- data/ext/sources/examples/talk-llama/llama.cpp +65 -10
- data/ext/sources/examples/talk-llama/llama.h +95 -177
- data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
- data/ext/sources/ggml/CMakeLists.txt +59 -31
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-backend.h +17 -1
- data/ext/sources/ggml/include/ggml-cpu.h +1 -1
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml.h +221 -16
- data/ext/sources/ggml/src/CMakeLists.txt +17 -2
- data/ext/sources/ggml/src/ggml-alloc.c +265 -141
- data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
- data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
- data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
- data/ext/sources/ggml/src/ggml-common.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
- data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
- data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
- data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
- data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
- data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
- data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
- data/ext/sources/ggml/src/ggml-impl.h +119 -9
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +111 -16
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml.c +478 -98
- data/ext/sources/ggml/src/gguf.cpp +8 -1
- data/ext/sources/src/whisper.cpp +23 -46
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +3 -3
- data/ext/sources/tests/test-vad.cpp +2 -2
- data/lib/whisper/model/uri.rb +1 -1
- data/sig/whisper.rbs +7 -0
- data/test/test_params.rb +8 -0
- data/test/test_whisper.rb +1 -1
- data/whispercpp.gemspec +1 -1
- metadata +164 -157
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
data/ext/sources/ggml/src/ggml-cpu/ops.cpp (excerpt):

```diff
@@ -8,6 +8,7 @@
 #include "vec.h"
 
 #include <float.h>
+#include <algorithm>
 
 // ggml_compute_forward_dup
 
@@ -40,13 +41,15 @@ static void ggml_compute_forward_dup_same_cont(
     }
 }
 
-static void ggml_compute_forward_dup_f16(
+template<typename src_t, typename dst_t>
+static void ggml_compute_forward_dup_flt(
         const ggml_compute_params * params,
               ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
 
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_ASSERT(!ggml_is_quantized(src0->type) && !ggml_is_quantized(dst->type));
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
@@ -61,6 +64,7 @@ static void ggml_compute_forward_dup_f16(
     const int ir0 = dr * ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
+    // case: type & row size equal
     if (src0->type == dst->type &&
         ne00 == ne0 &&
         nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
@@ -79,11 +83,11 @@ static void ggml_compute_forward_dup_f16(
         return;
     }
 
-    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
-
+    // case: dst tensor is contiguous
     if (ggml_is_contiguous(dst)) {
-        if (nb00 == sizeof(ggml_fp16_t)) {
-            if (dst->type == GGML_TYPE_F16) {
+        if (nb00 == sizeof(src_t)) {
+            if constexpr (std::is_same_v<dst_t, src_t>) {
+                // same type
                 size_t id = 0;
                 const size_t rs = ne00 * nb00;
                 char * dst_ptr = (char *) dst->data;
@@ -99,91 +103,46 @@ static void ggml_compute_forward_dup_f16(
                         id += rs * (ne01 - ir1);
                     }
                 }
-            } else if (dst->type == GGML_TYPE_F32) {
+            } else {
+                // casting between non-quantized types
                 size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
+                dst_t * dst_ptr = (dst_t *) dst->data;
 
                 for (int i03 = 0; i03 < ne03; i03++) {
                     for (int i02 = 0; i02 < ne02; i02++) {
                         id += ne00 * ir0;
                         for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                            const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                             for (int i00 = 0; i00 < ne00; i00++) {
-                                dst_ptr[id] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
+                                float tmp = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
+                                dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
                                 id++;
                             }
                         }
                         id += ne00 * (ne01 - ir1);
                     }
                 }
-            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
-                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
-                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
-
-                size_t id = 0;
-                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                src0_f32[i00] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
-                            }
-
-                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
             }
         } else {
             //printf("%s: this is not optimal - fix me\n", __func__);
 
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_CPU_FP16_TO_FP32(*src0_ptr);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+            size_t id = 0;
+            dst_t * dst_ptr = (dst_t *) dst->data;
 
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    id += ne00 * ir0;
+                    for (int i01 = ir0; i01 < ir1; i01++) {
+                        for (int i00 = 0; i00 < ne00; i00++) {
+                            const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
-                                dst_ptr[id] = *src0_ptr;
-                                id++;
-                            }
+                            float tmp = type_conversion_table<src_t>::to_f32(*src0_ptr);
+                            dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
+                            id++;
                         }
-                        id += ne00 * (ne01 - ir1);
                     }
+                    id += ne00 * (ne01 - ir1);
                 }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
             }
         }
         return;
@@ -195,7 +154,7 @@ static void ggml_compute_forward_dup_f16(
     int64_t i12 = 0;
     int64_t i13 = 0;
 
-    if (dst->type == GGML_TYPE_F16) {
+    if constexpr (std::is_same_v<dst_t, src_t>) {
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 i10 += ne00 * ir0;
@@ -216,7 +175,7 @@ static void ggml_compute_forward_dup_f16(
                         const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                         char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
 
-                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
+                        memcpy(dst_ptr, src0_ptr, sizeof(dst_t));
 
                         if (++i10 == ne00) {
                             i10 = 0;
@@ -247,7 +206,8 @@ static void ggml_compute_forward_dup_f16(
                 }
             }
         }
-    } else if (dst->type == GGML_TYPE_F32) {
+
+    } else {
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 i10 += ne00 * ir0;
@@ -268,7 +228,8 @@ static void ggml_compute_forward_dup_f16(
                         const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                         char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
 
-                        *(float *) dst_ptr = GGML_CPU_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
+                        float tmp = type_conversion_table<src_t>::to_f32(*(const src_t *) src0_ptr);
+                        *(dst_t *) dst_ptr = type_conversion_table<dst_t>::from_f32(tmp);
 
                         if (++i10 == ne0) {
                             i10 = 0;
@@ -299,18 +260,19 @@ static void ggml_compute_forward_dup_f16(
                 }
             }
         }
-    } else {
-        GGML_ABORT("fatal error"); // TODO: implement
     }
 }
 
-static void ggml_compute_forward_dup_bf16(
+
+template<typename src_t>
+static void ggml_compute_forward_dup_to_q(
         const ggml_compute_params * params,
               ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
 
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_ASSERT(!ggml_is_quantized(src0->type));
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
@@ -325,11 +287,73 @@ static void ggml_compute_forward_dup_bf16(
     const int ir0 = dr * ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
+    if (ggml_is_contiguous(dst) &&
+        nb00 == sizeof(src_t) &&
+        ggml_get_type_traits_cpu(dst->type)->from_float) {
+        // casting non-quantized types --> intermediate f32 --> quantized
+        ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+        float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+        size_t id = 0;
+        size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+        char * dst_ptr = (char *) dst->data;
+
+        for (int i03 = 0; i03 < ne03; i03++) {
+            for (int i02 = 0; i02 < ne02; i02++) {
+                id += rs * ir0;
+                for (int i01 = ir0; i01 < ir1; i01++) {
+                    const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                    for (int i00 = 0; i00 < ne00; i00++) {
+                        src0_f32[i00] = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
+                    }
+
+                    quantize_row_q(src0_f32, dst_ptr + id, ne00);
+                    id += rs;
+                }
+                id += rs * (ne01 - ir1);
+            }
+        }
+    } else {
+        // printf("%s %s\n", ggml_type_name(src0->type), ggml_type_name(dst->type));
+        GGML_ABORT("not implemented");
+    }
+}
+
+// A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
+static void ggml_compute_forward_dup_bytes(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_ASSERT(src0->type == dst->type);
+
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
+        ggml_compute_forward_dup_same_cont(params, dst);
+        return;
+    }
+
+    const size_t type_size = ggml_type_size(src0->type);
+
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
     if (src0->type == dst->type &&
-        ne00 == ne0 &&
-        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
+        ggml_are_same_shape(src0, dst) &&
+        nb00 == type_size && nb0 == type_size) {
         // copy by rows
-        const size_t rs = ne00*nb00;
+        const size_t rs = ggml_row_size(src0->type, ne00);
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 for (int64_t i01 = ir0; i01 < ir1; i01++) {
@@ -343,765 +367,110 @@ static void ggml_compute_forward_dup_bf16(
         return;
     }
 
-    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
-
     if (ggml_is_contiguous(dst)) {
-
-
-
-                const size_t rs = ne00 * nb00;
-                char * dst_ptr = (char *) dst->data;
+        size_t id = 0;
+        char * dst_ptr = (char *) dst->data;
+        const size_t rs = ne00 * type_size;
 
-
-
-
-
-
-
-
-
-                        id += rs
+        if (nb00 == type_size) {
+            // src0 is contigous on first dimension, copy by rows
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    id += rs * ir0;
+                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                        memcpy(dst_ptr + id, src0_ptr, rs);
+                        id += rs;
                     }
+                    id += rs * (ne01 - ir1);
                 }
-            }
-
-
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
 
-
-
-
-
-
-
-
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    id += rs * ir0;
+                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                        for (int64_t i00 = 0; i00 < ne00; i00++) {
+                            const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
+                            memcpy(dst_ptr + id, src0_ptr, type_size);
 
-
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
-                                id++;
-                            }
+                            id += type_size;
                         }
-                        id += ne00 * (ne01 - ir1);
                     }
+                    id += rs * (ne01 - ir1);
                 }
-            }
-
-            float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+            }
+        }
 
-
-
-            char * dst_ptr = (char *) dst->data;
+        return;
+    }
 
-
-
-
-
-
+    // dst counters
+    int64_t k10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
 
-
-
-
+    // number of blocks in a row
+    const int64_t nk00 = ne00 / ggml_blck_size(src0->type);
+    const int64_t nk0  = ne0  / ggml_blck_size(dst->type);
 
-
-
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            k10 += nk00 * ir0;
+            while (k10 >= nk0) {
+                k10 -= nk0;
+                if (++i11 == ne1) {
+                    i11 = 0;
+                    if (++i12 == ne2) {
+                        i12 = 0;
+                        if (++i13 == ne3) {
+                            i13 = 0;
                         }
-                        id += rs * (ne01 - ir1);
                     }
                 }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
             }
-
-
-
-
-            size_t id = 0;
-            float * dst_ptr = (float *) dst->data;
-
-            for (int i03 = 0; i03 < ne03; i03++) {
-                for (int i02 = 0; i02 < ne02; i02++) {
-                    id += ne00 * ir0;
-                    for (int i01 = ir0; i01 < ir1; i01++) {
-                        for (int i00 = 0; i00 < ne00; i00++) {
-                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                            dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
-                            id++;
-                        }
-                    }
-                    id += ne00 * (ne01 - ir1);
-                }
-            }
-        } else if (dst->type == GGML_TYPE_BF16) {
-            size_t id = 0;
-            ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
+            for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                for (int64_t k00 = 0; k00 < nk00; k00++) {
+                    const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                    char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
 
-
-            for (int i02 = 0; i02 < ne02; i02++) {
-                id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
-                    for (int i00 = 0; i00 < ne00; i00++) {
-                        const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                    memcpy(dst_ptr, src0_ptr, type_size);
 
-
-
+                    if (++k10 == nk0) {
+                        k10 = 0;
+                        if (++i11 == ne1) {
+                            i11 = 0;
+                            if (++i12 == ne2) {
+                                i12 = 0;
+                                if (++i13 == ne3) {
+                                    i13 = 0;
+                                }
                             }
                         }
-                    id += ne00 * (ne01 - ir1);
                     }
                 }
-            }
-
-
-
-
-
-
-
-
-
-
-                        dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
-                        id++;
-                    }
+            }
+            k10 += nk00 * (ne01 - ir1);
+            while (k10 >= nk0) {
+                k10 -= nk0;
+                if (++i11 == ne1) {
+                    i11 = 0;
+                    if (++i12 == ne2) {
+                        i12 = 0;
+                        if (++i13 == ne3) {
+                            i13 = 0;
                         }
-                    id += ne00 * (ne01 - ir1);
                     }
                 }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
            }
        }
-        return;
    }
+}
 
-
-    int64_t i10 = 0;
-    int64_t i11 = 0;
-    int64_t i12 = 0;
-    int64_t i13 = 0;
-
-    if (dst->type == GGML_TYPE_BF16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
-                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
-
-                        if (++i10 == ne00) {
-                            i10 = 0;
-                            if (++i11 == ne01) {
-                                i11 = 0;
-                                if (++i12 == ne02) {
-                                    i12 = 0;
-                                    if (++i13 == ne03) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_F16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
-                        *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_F32) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
-                        *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else {
-        GGML_ABORT("fatal error"); // TODO: implement
-    }
-}
-
-static void ggml_compute_forward_dup_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst) {
-
-    const ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    const int ith = params->ith; // thread index
-    const int nth = params->nth; // number of threads
-
-    // parallelize by rows
-    const int nr = ne01;
-    // number of rows per thread
-    const int dr = (nr + nth - 1) / nth;
-    // row range for this thread
-    const int ir0 = dr * ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (src0->type == dst->type &&
-            ne00 == ne0 &&
-            nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
-        // copy by rows
-        const size_t rs = ne00*nb00;
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    memcpy(
-                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
-                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
-                        rs);
-                }
-            }
-        }
-        return;
-    }
-
-    if (ggml_is_contiguous(dst)) {
-        // TODO: simplify
-        if (nb00 == sizeof(float)) {
-            if (ggml_get_type_traits_cpu(dst->type)->from_float) {
-                ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
-
-                size_t id = 0;
-                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            from_float(src0_ptr, dst_ptr + id, ne00);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        } else {
-            //printf("%s: this is not optimal - fix me\n", __func__);
-
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = *src0_ptr;
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_CPU_FP32_TO_FP16(*src0_ptr);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_BF16) {
-                size_t id = 0;
-                ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        }
-
-        return;
-    }
-
-    // dst counters
-
-    int64_t i10 = 0;
-    int64_t i11 = 0;
-    int64_t i12 = 0;
-    int64_t i13 = 0;
-
-    if (dst->type == GGML_TYPE_F32) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
-                        memcpy(dst_ptr, src0_ptr, sizeof(float));
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_F16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
-                        *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(*(const float *) src0_ptr);
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_BF16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
-                        *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else {
-        GGML_ABORT("fatal error"); // TODO: implement
-    }
-}
-
-// A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
-static void ggml_compute_forward_dup_bytes(
-        const ggml_compute_params * params,
-        ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-    GGML_ASSERT(src0->type == dst->type);
-
-    GGML_TENSOR_UNARY_OP_LOCALS;
-
-    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
-        ggml_compute_forward_dup_same_cont(params, dst);
-        return;
-    }
-
-    const size_t type_size = ggml_type_size(src0->type);
-
-    const int ith = params->ith; // thread index
-    const int nth = params->nth; // number of threads
-
-    // parallelize by rows
-    const int nr = ne01;
-    // number of rows per thread
-    const int dr = (nr + nth - 1) / nth;
-    // row range for this thread
-    const int ir0 = dr * ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (src0->type == dst->type &&
-            ggml_are_same_shape(src0, dst) &&
-            nb00 == type_size && nb0 == type_size) {
-        // copy by rows
-        const size_t rs = ggml_row_size(src0->type, ne00);
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    memcpy(
-                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
-                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
-                        rs);
-                }
-            }
-        }
-        return;
-    }
-
-    if (ggml_is_contiguous(dst)) {
-        size_t id = 0;
-        char * dst_ptr = (char *) dst->data;
-        const size_t rs = ne00 * type_size;
-
-        if (nb00 == type_size) {
-            // src0 is contigous on first dimension, copy by rows
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    id += rs * ir0;
-                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                        memcpy(dst_ptr + id, src0_ptr, rs);
-                        id += rs;
-                    }
-                    id += rs * (ne01 - ir1);
-                }
-            }
-        } else {
-            //printf("%s: this is not optimal - fix me\n", __func__);
-
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    id += rs * ir0;
-                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                        for (int64_t i00 = 0; i00 < ne00; i00++) {
-                            const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
-                            memcpy(dst_ptr + id, src0_ptr, type_size);
-
-                            id += type_size;
-                        }
-                    }
-                    id += rs * (ne01 - ir1);
-                }
-            }
-        }
-
-        return;
-    }
-
-    // dst counters
-    int64_t k10 = 0;
-    int64_t i11 = 0;
-    int64_t i12 = 0;
-    int64_t i13 = 0;
-
-    // number of blocks in a row
-    const int64_t nk00 = ne00 / ggml_blck_size(src0->type);
-    const int64_t nk0  = ne0  / ggml_blck_size(dst->type);
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            k10 += nk00 * ir0;
-            while (k10 >= nk0) {
-                k10 -= nk0;
-                if (++i11 == ne1) {
-                    i11 = 0;
-                    if (++i12 == ne2) {
-                        i12 = 0;
-                        if (++i13 == ne3) {
-                            i13 = 0;
-                        }
-                    }
-                }
-            }
-            for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                for (int64_t k00 = 0; k00 < nk00; k00++) {
-                    const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                    char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
-
-                    memcpy(dst_ptr, src0_ptr, type_size);
-
-                    if (++k10 == nk0) {
-                        k10 = 0;
-                        if (++i11 == ne1) {
-                            i11 = 0;
-                            if (++i12 == ne2) {
-                                i12 = 0;
-                                if (++i13 == ne3) {
-                                    i13 = 0;
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-            k10 += nk00 * (ne01 - ir1);
-            while (k10 >= nk0) {
-                k10 -= nk0;
-                if (++i11 == ne1) {
-                    i11 = 0;
-                    if (++i12 == ne2) {
-                        i12 = 0;
-                        if (++i13 == ne3) {
-                            i13 = 0;
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_dup_q(
+static void ggml_compute_forward_dup_from_q(
     const ggml_compute_params * params,
     ggml_tensor * dst) {
 
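The `k10`/`i11`/`i12`/`i13` destination counters in the new byte-copy path behave like a mixed-radix odometer: the fastest index advances one block at a time and carries into the slower dimensions whenever it wraps. A reduced, standalone sketch of that carry step (names mirror the diff; this is not the library code):

```cpp
#include <cstdint>

// Advance (k10, i11, i12, i13) by `steps` positions over a destination of
// extents (nk0, ne1, ne2, ne3), carrying into slower dimensions on wrap.
static void advance(int64_t steps, int64_t &k10, int64_t &i11, int64_t &i12, int64_t &i13,
                    int64_t nk0, int64_t ne1, int64_t ne2, int64_t ne3) {
    k10 += steps;
    while (k10 >= nk0) {
        k10 -= nk0;
        if (++i11 == ne1) {
            i11 = 0;
            if (++i12 == ne2) {
                i12 = 0;
                if (++i13 == ne3) {
                    i13 = 0; // wrapped the whole tensor
                }
            }
        }
    }
}
```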
@@ -1166,20 +535,35 @@ void ggml_compute_forward_dup(
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-
+                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_fp16_t>(params, dst);
+                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_bf16_t>(params, dst);
+                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<ggml_fp16_t, float>(params, dst);
+                else                                  ggml_compute_forward_dup_to_q<ggml_fp16_t>(params, dst);
             } break;
         case GGML_TYPE_BF16:
             {
-
+                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_fp16_t>(params, dst);
+                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_bf16_t>(params, dst);
+                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<ggml_bf16_t, float>(params, dst);
+                else                                  ggml_compute_forward_dup_to_q<ggml_bf16_t>(params, dst);
             } break;
         case GGML_TYPE_F32:
             {
-
+                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<float, ggml_fp16_t>(params, dst);
+                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<float, ggml_bf16_t>(params, dst);
+                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<float, float>(params, dst);
+                else if (dst->type == GGML_TYPE_I32)  ggml_compute_forward_dup_flt<float, int32_t>(params, dst);
+                else                                  ggml_compute_forward_dup_to_q<float>(params, dst);
+            } break;
+        case GGML_TYPE_I32:
+            {
+                if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<int32_t, float>(params, dst);
+                else GGML_ABORT("not implemented");
             } break;
         default:
             {
                 if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
-
+                    ggml_compute_forward_dup_from_q(params, dst);
                     break;
                 }
                 GGML_ABORT("fatal error");
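The dispatch above replaces the per-type dup functions that were deleted earlier with one conversion kernel, `ggml_compute_forward_dup_flt<src_t, dst_t>`, instantiated per (source, destination) element-type pair. A simplified sketch of the core idea such a template boils down to (plain assignment stands in for the FP16/BF16 conversion helpers the real kernel uses; strides and threading are omitted):

```cpp
#include <cstdint>

// Copy ne elements, converting the element type on the fly. One template
// can stand in for the former F16/BF16/F32/I32 dup variants.
template <typename src_t, typename dst_t>
static void dup_row(const src_t * src, dst_t * dst, int64_t ne) {
    for (int64_t i = 0; i < ne; ++i) {
        dst[i] = static_cast<dst_t>(src[i]); // real code converts via FP16/BF16 helpers
    }
}
```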
@@ -1252,20 +636,118 @@ static void ggml_compute_forward_add_q_f32(
 
     assert(ne00 % 32 == 0);
 
-    // unquantize row from src0 to temp buffer
-    dequantize_row_q(src0_row, wdata, ne00);
-    // add src1
-    ggml_vec_acc_f32(ne00, wdata, src1_row);
-    // quantize row to dst
-    if (quantize_row_q != NULL) {
-        quantize_row_q(wdata, dst_row, ne00);
-    } else {
-        memcpy(dst_row, wdata, ne0*nb0);
-    }
+        // unquantize row from src0 to temp buffer
+        dequantize_row_q(src0_row, wdata, ne00);
+        // add src1
+        ggml_vec_acc_f32(ne00, wdata, src1_row);
+        // quantize row to dst
+        if (quantize_row_q != NULL) {
+            quantize_row_q(wdata, dst_row, ne00);
+        } else {
+            memcpy(dst_row, wdata, ne0*nb0);
+        }
+    }
+}
+
+void ggml_compute_forward_add(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
+            {
+                ggml_compute_forward_add_non_quantized(params, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+            {
+                ggml_compute_forward_add_q_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_add_id
+
+static void ggml_compute_forward_add_id_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        // src1 indices
+        const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
+
+        GGML_ASSERT(i11 >= 0 && i11 < ne11);
+
+        ggml_vec_add_f32(ne0,
+                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1 ),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+                (float *) ((char *) src1->data + i11*nb11));
     }
 }
 
-void
+void ggml_compute_forward_add_id(
     const ggml_compute_params * params,
     ggml_tensor * dst) {
 
@@ -1273,38 +755,12 @@ void ggml_compute_forward_add(
 
     switch (src0->type) {
         case GGML_TYPE_F32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_BF16:
-            {
-                ggml_compute_forward_add_non_quantized(params, dst);
-            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_TQ1_0:
-        case GGML_TYPE_TQ2_0:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-        case GGML_TYPE_IQ2_S:
             {
-
+                ggml_compute_forward_add_id_f32(params, dst);
             } break;
         default:
             {
-                GGML_ABORT("
+                GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
             }
     }
 }
@@ -1660,6 +1116,7 @@ void ggml_compute_forward_add1(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -1787,6 +1244,7 @@ void ggml_compute_forward_acc(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -3009,50 +2467,304 @@ static void ggml_compute_forward_leaky_relu_f32(
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
-    float negative_slope;
-    memcpy(&negative_slope, dst->op_params, sizeof(float));
+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_leaky_relu_f32(nc,
+                (float *) ((char *)  dst->data + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
+    }
+}
+
+static void ggml_compute_forward_leaky_relu_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
+
+    assert(dst->nb[0]  == sizeof(ggml_fp16_t));
+    assert(src0->nb[0] == sizeof(ggml_fp16_t));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_leaky_relu_f16(nc,
+                (ggml_fp16_t *) ((char *)  dst->data + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
+    }
+}
+
+void ggml_compute_forward_leaky_relu(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_leaky_relu_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_leaky_relu_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_silu_back
+
+static void ggml_compute_forward_silu_back_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * grad = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    assert(ggml_is_contiguous_1(grad));
+    assert(ggml_is_contiguous_1(src1));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src1, dst));
+    assert(ggml_are_same_shape(src1, grad));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1->ne[0];
+    const int nr = ggml_nrows(src1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_silu_backward_f32(nc,
+                (float *) ((char *)  dst->data + i1*( dst->nb[1])),
+                (float *) ((char *) src1->data + i1*(src1->nb[1])),
+                (float *) ((char *) grad->data + i1*(grad->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_silu_back_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * grad = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    assert(ggml_is_contiguous_1(grad));
+    assert(ggml_is_contiguous_1(src1));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src1, dst));
+    assert(ggml_are_same_shape(src1, grad));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1->ne[0];
+    const int nr = ggml_nrows(src1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_silu_backward_f16(nc,
+                (ggml_fp16_t *) ((char *)  dst->data + i1*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
+                (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_CPU_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+void ggml_compute_forward_silu_back(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_silu_back_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_silu_back_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_reglu
+
+static void ggml_compute_forward_reglu_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
 
-
-
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
 
-
-
-
-
+        ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
     }
 }
 
-static void
+static void ggml_compute_forward_reglu_f16(
     const ggml_compute_params * params,
     ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 
-
-
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
     }
 
-
-
-    assert(ggml_are_same_shape(src0, dst));
+    const int ith = params->ith;
+    const int nth = params->nth;
 
-    const int
-    const int
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
 
-
-
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
 
-
-    assert(src0->nb[0] == sizeof(ggml_fp16_t));
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
 
-
-
-
-
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
     }
 }
 
-void
+static void ggml_compute_forward_reglu(
     const ggml_compute_params * params,
     ggml_tensor * dst) {
 
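All of the GLU kernels added above (`reglu`, and the `geglu`/`swiglu` variants that follow) share one input convention: with a separate `src1`, the gate comes from `src1` and `nc` is the full row width; with a single fused input, each row holds two halves `[a | b]`, `nc` is `ne[0] / 2`, and the `swapped` op param decides which half feeds the activation and which is the gate. A standalone sketch of that selection logic (the `act` callback is a stand-in for ReLU/GELU/SiLU; not the library code):

```cpp
// Select the activation input (a) and gate (b) for one row, then apply
// dst[k] = act(a[k]) * b[k] -- the common shape of all the GLU kernels.
static void glu_row_f32(const float * x, const float * g_or_null, float * dst,
                        int nc_full, bool swapped, float (*act)(float)) {
    const float * a = x;
    const float * b = g_or_null;
    int nc = nc_full;
    if (!b) {              // fused layout: split the row into two halves
        nc = nc_full / 2;
        a  = x + (swapped ? nc : 0);
        b  = x + (swapped ? 0 : nc);
    }
    for (int k = 0; k < nc; ++k) {
        dst[k] = act(a[k]) * b[k];
    }
}
```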
@@ -3061,11 +2773,11 @@ void ggml_compute_forward_leaky_relu(
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-
+                ggml_compute_forward_reglu_f32(params, dst);
             } break;
         case GGML_TYPE_F16:
             {
-
+                ggml_compute_forward_reglu_f16(params, dst);
             } break;
         default:
             {
@@ -3074,26 +2786,37 @@ void ggml_compute_forward_leaky_relu(
             }
     }
 }
 
-//
+// ggml_compute_forward_geglu
 
-static void
+static void ggml_compute_forward_geglu_f32(
     const ggml_compute_params * params,
     ggml_tensor * dst) {
 
-    const ggml_tensor *
+    const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 
-
-
-
-
-
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nc = src1->ne[0];
-    const int nr = ggml_nrows(
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -3103,10 +2826,15 @@ static void ggml_compute_forward_silu_back_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
-
-
-
-
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3119,24 +2847,35 @@ static void ggml_compute_forward_silu_back_f32(
     }
 }
 
-static void
+static void ggml_compute_forward_geglu_f16(
     const ggml_compute_params * params,
     ggml_tensor * dst) {
 
-    const ggml_tensor *
+    const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 
-
-
-
-
-
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nc = src1->ne[0];
-    const int nr = ggml_nrows(
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -3146,24 +2885,29 @@ static void ggml_compute_forward_silu_back_f16(
     const int ir1 = MIN(ir0 + dr, nr);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
-
-
-        (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
-        (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
 
-
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
-            const
-            const float v =
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
             GGML_UNUSED(v);
             assert(!isnan(v));
             assert(!isinf(v));
         }
-
+#endif
     }
 }
 
-void
+static void ggml_compute_forward_geglu(
     const ggml_compute_params * params,
     ggml_tensor * dst) {
 
@@ -3172,11 +2916,11 @@ void ggml_compute_forward_silu_back(
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-
+                ggml_compute_forward_geglu_f32(params, dst);
             } break;
         case GGML_TYPE_F16:
             {
-
+                ggml_compute_forward_geglu_f16(params, dst);
             } break;
         default:
             {
@@ -3185,9 +2929,9 @@ void ggml_compute_forward_silu_back(
             }
     }
 }
 
-//
+// ggml_compute_forward_swiglu
 
-static void
+static void ggml_compute_forward_swiglu_f32(
     const ggml_compute_params * params,
     ggml_tensor * dst) {
 
@@ -3233,7 +2977,7 @@ static void ggml_compute_forward_reglu_f32(
             src1_p += swapped ? 0 : nc;
         }
 
-
+        ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3246,9 +2990,93 @@ static void ggml_compute_forward_reglu_f32(
     }
 }
 
-static void
-    const ggml_compute_params * params,
-    ggml_tensor * dst) {
+static void ggml_compute_forward_swiglu_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_swiglu(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_swiglu_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_swiglu_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_swiglu_oai
+
+static void ggml_compute_forward_swiglu_oai_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
@@ -3275,6 +3103,8 @@ static void ggml_compute_forward_reglu_f16(
     GGML_ASSERT(ggml_nrows(dst) == nr);
 
     const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+    const float alpha = ggml_get_op_params_f32(dst, 2);
+    const float limit = ggml_get_op_params_f32(dst, 3);
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -3284,29 +3114,34 @@ static void ggml_compute_forward_reglu_f16(
     const int ir1 = MIN(ir0 + dr, nr);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
-
-
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+        float * dst_p  = (float *) ((char *) dst->data + i1*(dst->nb[1]));
 
         if (!src1) {
             src0_p += swapped ? nc : 0;
             src1_p += swapped ? 0 : nc;
         }
 
-
+        for (int k = 0; k < nc; k++) {
+            const float x = std::min(src0_p[k], limit);
+            const float y = std::clamp(src1_p[k], -limit, limit);
+            const float out_glu = x / (1.f + expf(alpha * (-x)));
+            dst_p[k] = out_glu * (y + 1.f);
+        }
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
-            const
-
-
-            assert(!
-            assert(!isinf(v));
+            const float x = dst_p[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
         }
 #endif
     }
 }
 
-static void
+static void ggml_compute_forward_swiglu_oai(
     const ggml_compute_params * params,
     ggml_tensor * dst) {
 
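The `swiglu_oai` variant above clamps both halves before gating: the activation input is capped at `limit`, the gate is clamped to `[-limit, limit]`, the sigmoid gets a tunable `alpha`, and the gate is shifted by one. Restated per element as a standalone sketch (taken directly from the loop body in the diff):

```cpp
#include <algorithm>
#include <cmath>

// One element of the clamped SwiGLU computed by swiglu_oai_f32 above.
static float swiglu_oai(float x, float y, float alpha, float limit) {
    x = std::min(x, limit);                                // cap activation input
    y = std::clamp(y, -limit, limit);                      // clamp gate
    const float glu = x / (1.f + std::exp(alpha * (-x)));  // alpha-scaled SiLU
    return glu * (y + 1.f);                                // shifted gate
}
```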
@@ -3315,11 +3150,7 @@ static void ggml_compute_forward_reglu(
|
|
3315
3150
|
switch (src0->type) {
|
3316
3151
|
case GGML_TYPE_F32:
|
3317
3152
|
{
|
3318
|
-
|
3319
|
-
} break;
|
3320
|
-
case GGML_TYPE_F16:
|
3321
|
-
{
|
3322
|
-
ggml_compute_forward_reglu_f16(params, dst);
|
3153
|
+
ggml_compute_forward_swiglu_oai_f32(params, dst);
|
3323
3154
|
} break;
|
3324
3155
|
default:
|
3325
3156
|
{
|
@@ -3328,9 +3159,9 @@ static void ggml_compute_forward_reglu(
|
|
3328
3159
|
}
|
3329
3160
|
}
|
3330
3161
|
|
3331
|
-
//
|
3162
|
+
// ggml_compute_forward_geglu_erf
|
3332
3163
|
|
3333
|
-
static void
|
3164
|
+
static void ggml_compute_forward_geglu_erf_f32(
|
3334
3165
|
const ggml_compute_params * params,
|
3335
3166
|
ggml_tensor * dst) {
|
3336
3167
|
|
@@ -3376,7 +3207,7 @@ static void ggml_compute_forward_geglu_f32(
|
|
3376
3207
|
src1_p += swapped ? 0 : nc;
|
3377
3208
|
}
|
3378
3209
|
|
3379
|
-
|
3210
|
+
ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
3380
3211
|
|
3381
3212
|
#ifndef NDEBUG
|
3382
3213
|
for (int k = 0; k < nc; k++) {
|
@@ -3389,7 +3220,7 @@ static void ggml_compute_forward_geglu_f32(
|
|
3389
3220
|
}
|
3390
3221
|
}
|
3391
3222
|
|
3392
|
-
static void
|
3223
|
+
static void ggml_compute_forward_geglu_erf_f16(
|
3393
3224
|
const ggml_compute_params * params,
|
3394
3225
|
ggml_tensor * dst) {
|
3395
3226
|
|
@@ -3435,7 +3266,7 @@ static void ggml_compute_forward_geglu_f16(
|
|
3435
3266
|
src1_p += swapped ? 0 : nc;
|
3436
3267
|
}
|
3437
3268
|
|
3438
|
-
|
3269
|
+
ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
3439
3270
|
|
3440
3271
|
#ifndef NDEBUG
|
3441
3272
|
for (int k = 0; k < nc; k++) {
|
@@ -3449,7 +3280,7 @@ static void ggml_compute_forward_geglu_f16(
|
|
3449
3280
|
}
|
3450
3281
|
}
|
3451
3282
|
|
3452
|
-
static void
|
3283
|
+
static void ggml_compute_forward_geglu_erf(
|
3453
3284
|
const ggml_compute_params * params,
|
3454
3285
|
ggml_tensor * dst) {
|
3455
3286
|
|
@@ -3458,11 +3289,11 @@ static void ggml_compute_forward_geglu(
|
|
3458
3289
|
switch (src0->type) {
|
3459
3290
|
case GGML_TYPE_F32:
|
3460
3291
|
{
|
3461
|
-
|
3292
|
+
ggml_compute_forward_geglu_erf_f32(params, dst);
|
3462
3293
|
} break;
|
3463
3294
|
case GGML_TYPE_F16:
|
3464
3295
|
{
|
3465
|
-
|
3296
|
+
ggml_compute_forward_geglu_erf_f16(params, dst);
|
3466
3297
|
} break;
|
3467
3298
|
default:
|
3468
3299
|
{
|
@@ -3471,9 +3302,9 @@ static void ggml_compute_forward_geglu(
|
|
3471
3302
|
}
|
3472
3303
|
}
|
3473
3304
|
|
3474
|
-
//
|
3305
|
+
// ggml_compute_forward_geglu_quick
|
3475
3306
|
|
3476
|
-
static void
|
3307
|
+
static void ggml_compute_forward_geglu_quick_f32(
|
3477
3308
|
const ggml_compute_params * params,
|
3478
3309
|
ggml_tensor * dst) {
|
3479
3310
|
|
@@ -3519,7 +3350,7 @@ static void ggml_compute_forward_swiglu_f32(
|
|
3519
3350
|
src1_p += swapped ? 0 : nc;
|
3520
3351
|
}
|
3521
3352
|
|
3522
|
-
|
3353
|
+
ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
3523
3354
|
|
3524
3355
|
#ifndef NDEBUG
|
3525
3356
|
for (int k = 0; k < nc; k++) {
|
@@ -3532,7 +3363,7 @@ static void ggml_compute_forward_swiglu_f32(
|
|
3532
3363
|
}
|
3533
3364
|
}
|
3534
3365
|
|
3535
|
-
static void
|
3366
|
+
static void ggml_compute_forward_geglu_quick_f16(
|
3536
3367
|
const ggml_compute_params * params,
|
3537
3368
|
ggml_tensor * dst) {
|
3538
3369
|
|
@@ -3578,7 +3409,7 @@ static void ggml_compute_forward_swiglu_f16(
             src1_p += swapped ? 0 : nc;
         }
 
-        ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+        ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3592,7 +3423,7 @@ static void ggml_compute_forward_swiglu_f16(
     }
 }
 
-static void ggml_compute_forward_swiglu(
+static void ggml_compute_forward_geglu_quick(
        const ggml_compute_params * params,
        ggml_tensor * dst) {
 
@@ -3601,11 +3432,11 @@ static void ggml_compute_forward_swiglu(
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_swiglu_f32(params, dst);
+                ggml_compute_forward_geglu_quick_f32(params, dst);
             } break;
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_swiglu_f16(params, dst);
+                ggml_compute_forward_geglu_quick_f16(params, dst);
             } break;
         default:
             {
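Aside: GEGLU-quick replaces the exact erf with the cheap sigmoid approximation gelu_quick(v) = v*sigmoid(1.702*v). A sketch under the same operand assumption as above, not the vendored kernel:

#include <math.h>

static void geglu_quick_ref(int n, float * y, const float * x, const float * g) {
    for (int i = 0; i < n; ++i) {
        const float v = x[i];
        y[i] = (v/(1.0f + expf(-1.702f*v))) * g[i]; // v*sigmoid(1.702*v) approximates gelu(v)
    }
}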
@@ -3729,6 +3560,9 @@ static void ggml_compute_forward_rms_norm_f32(
 
         const float scale = 1.0f/sqrtf(mean + eps);
 
+        // if you hit this, likely you got an inf somewhere earlier
+        assert(scale > 0.0f);
+
         ggml_vec_scale_f32(ne00, y, scale);
     }
 }
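Aside: the new assert documents an invariant of RMS norm: with mean = sum(x_i^2)/n, the factor 1/sqrt(mean + eps) is finite and positive unless an inf/NaN leaked in earlier. Reference row computation (a sketch, not the vendored code):

#include <assert.h>
#include <math.h>

static void rms_norm_row_ref(int n, float * y, const float * x, float eps) {
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        sum += x[i]*x[i];
    }
    const float scale = 1.0f/sqrtf(sum/n + eps);
    assert(scale > 0.0f); // trips when sum became non-finite upstream
    for (int i = 0; i < n; ++i) {
        y[i] = x[i]*scale;
    }
}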
@@ -4310,6 +4144,7 @@ void ggml_compute_forward_out_prod(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -4357,9 +4192,11 @@ static void ggml_compute_forward_scale_f32(
     GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
-    // scale factor
-    float v;
-    memcpy(&v, dst->op_params, sizeof(float));
+    float s; // scale factor
+    float b; // bias
+
+    memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -4378,12 +4215,22 @@ static void ggml_compute_forward_scale_f32(
 
     const size_t nb1 = dst->nb[1];
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        if (dst->data != src0->data) {
-            // src0 is same shape as dst => same indices
-            memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb1, nc * sizeof(float));
+    if (b == 0.0f) {
+        for (int i1 = ir0; i1 < ir1; i1++) {
+            if (dst->data != src0->data) {
+                // src0 is same shape as dst => same indices
+                // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
+                memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
+            }
+            ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
+        }
+    } else {
+        for (int i1 = ir0; i1 < ir1; i1++) {
+            ggml_vec_mad1_f32(nc,
+                    (float *) ((char *) dst->data + i1*nb1),
+                    (float *) ((char *) src0->data + i1*nb1),
+                    s, b);
         }
-        ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
     }
 }
 
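Aside: GGML_OP_SCALE now reads two op params, a scale s and a bias b, and the b != 0 path fuses them into one pass. Judging by the call shape, ggml_vec_mad1_f32 computes y = x*s + b with scalar coefficients; a scalar sketch of that helper (an assumption, not the vendored implementation):

// y[i] = x[i]*s + b over one row
static void vec_mad1_ref(int n, float * y, const float * x, float s, float b) {
    for (int i = 0; i < n; ++i) {
        y[i] = x[i]*s + b;
    }
}

The b == 0 branch keeps the old copy-then-scale path, so plain scaling pays no extra cost.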
@@ -4572,6 +4419,7 @@ void ggml_compute_forward_set(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -4833,6 +4681,7 @@ void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -4890,6 +4739,7 @@ void ggml_compute_forward_get_rows(
     //}
 }
 
+template<typename idx_t>
 static void ggml_compute_forward_set_rows_f32(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
@@ -4928,7 +4778,7 @@ static void ggml_compute_forward_set_rows_f32(
             const int64_t i11 = i02%ne11;
             const int64_t i10 = i;
 
-            const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+            const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
             GGML_ASSERT(i1 >= 0 && i1 < ne1);
 
@@ -4945,11 +4795,18 @@ void ggml_compute_forward_set_rows(
         ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
 
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_set_rows_f32(params, dst);
+                if (src1->type == GGML_TYPE_I64) {
+                    ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
+                } else if (src1->type == GGML_TYPE_I32) {
+                    ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
+                } else {
+                    GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type));
+                }
             } break;
         default:
             {
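Aside: set_rows is now templated on the index type, so src1 may hold either I64 or I32 row ids and the dispatcher picks the instantiation at runtime. The same pattern in miniature (hypothetical names, a sketch):

#include <cstdint>

template <typename idx_t>
static void scatter_rows(float * dst, const float * src, const void * ids, int nrows, int ncols) {
    const idx_t * idx = (const idx_t *) ids;
    for (int r = 0; r < nrows; ++r) {
        for (int c = 0; c < ncols; ++c) {
            dst[(int64_t) idx[r]*ncols + c] = src[(int64_t) r*ncols + c]; // row r lands at row idx[r]
        }
    }
}
// caller: ids_are_i64 ? scatter_rows<int64_t>(...) : scatter_rows<int32_t>(...);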
@@ -5222,6 +5079,7 @@ static void ggml_compute_forward_soft_max_f32(
 
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
 
     assert(ggml_is_contiguous(dst));
     assert(ggml_are_same_shape(src0, dst));
@@ -5232,14 +5090,17 @@ static void ggml_compute_forward_soft_max_f32(
     memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
-    // TODO: handle transposed/permuted matrices
-
     const int ith = params->ith;
     const int nth = params->nth;
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-    //const int64_t ne11 = src1 ? src1->ne[1] : 1;
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;
 
     // TODO: is this supposed to be ceil instead of floor?
     //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -5249,68 +5110,78 @@ static void ggml_compute_forward_soft_max_f32(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+    float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        // ALiBi
-        const uint32_t h = (i1/ne01)%ne02; // head
-        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
-        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
+    // sinks
+    const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
 
-        // broadcast the mask across rows
-        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-
-        ggml_vec_cpy_f32  (nc, wp, sp);
-        ggml_vec_scale_f32(nc, wp, scale);
-        if (mp_f32) {
-            if (use_f16) {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
-                }
-            } else {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*mp_f32[i];
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const int64_t i11 = i01;
+                const int64_t i12 = i02%ne12;
+                const int64_t i13 = i03%ne13;
+
+                // ALiBi
+                const uint32_t h = i02; // head
+                const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+                float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                float * dp = (float *)((char *) dst->data  + i01*nb1  + i02*nb2  + i03*nb3);
+
+                // broadcast the mask across rows
+                ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+                float       * mp_f32 = src1 ? (float       *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+
+                ggml_vec_cpy_f32  (ne00, wp, sp);
+                ggml_vec_scale_f32(ne00, wp, scale);
+                if (mp_f32) {
+                    if (use_f16) {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
+                        }
+                    } else {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*mp_f32[i];
+                        }
+                    }
                 }
-            }
-        }
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(wp[i]));
-        }
+                for (int i = 0; i < ne00; ++i) {
+                    //printf("p[%d] = %f\n", i, p[i]);
+                    assert(!isnan(wp[i]));
+                }
 #endif
 
-        float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, wp);
+                float max = -INFINITY;
+                ggml_vec_max_f32(ne00, &max, wp);
 
-        ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
-        assert(sum > 0.0);
+                // if we have sinks, make a correction as if they were included in the softmax
+                if (sk) {
+                    max = MAX(max, sk[i02]);
+                }
+
+                ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
+                assert(sum > 0.0);
+
+                if (sk) {
+                    sum += (ggml_float) expf(sk[i02] - max);
+                }
 
-        sum = 1.0/sum;
-        ggml_vec_scale_f32(nc, dp, sum);
+                sum = 1.0/sum;
+                ggml_vec_scale_f32(ne00, dp, sum);
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            assert(!isnan(dp[i]));
-            assert(!isinf(dp[i]));
-        }
+                for (int i = 0; i < ne00; ++i) {
+                    assert(!isnan(dp[i]));
+                    assert(!isinf(dp[i]));
+                }
 #endif
+            }
+        }
     }
 }
 
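Aside: an attention "sink" is a per-head logit that joins the softmax normalization without producing an output column. The hunk folds it in by correcting the row max and adding exp(sink - max) to the denominator. A scalar sketch of the corrected normalization (reference semantics under that reading, not the vendored code):

#include <math.h>

static void soft_max_with_sink_ref(int n, float * p, const float * w, const float * sink) {
    float max = -INFINITY;
    for (int i = 0; i < n; ++i) {
        max = w[i] > max ? w[i] : max;
    }
    if (sink && *sink > max) {
        max = *sink; // as if the sink logit were part of the row
    }
    double sum = 0.0;
    for (int i = 0; i < n; ++i) {
        p[i] = expf(w[i] - max);
        sum += p[i];
    }
    if (sink) {
        sum += expf(*sink - max); // the sink absorbs probability mass but emits no value
    }
    for (int i = 0; i < n; ++i) {
        p[i] /= (float) sum;
    }
}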
@@ -5534,6 +5405,7 @@ void ggml_compute_forward_clamp(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -6460,7 +6332,195 @@ void ggml_compute_forward_im2col_back_f32(
     const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
     const ggml_tensor * src1 = dst->src[1]; // convolution kernel
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int32_t s0    = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1    = ((const int32_t *)(dst->op_params))[1];
+    const int32_t p0    = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p1    = ((const int32_t *)(dst->op_params))[3];
+    const int32_t d0    = ((const int32_t *)(dst->op_params))[4];
+    const int32_t d1    = ((const int32_t *)(dst->op_params))[5];
+    const bool    is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t N  = is_2D ? ne3 : ne2;
+    const int64_t IC = is_2D ? ne2 : ne1;
+    const int64_t IH = is_2D ? ne1 : 1;
+    const int64_t IW = ne0;
+
+    const int64_t KH = is_2D ? ne11 : 1;
+    const int64_t KW = ne10;
+
+    const int64_t OH = is_2D ? ne02 : 1;
+    const int64_t OW = ne01;
+
+    int ofs0 = is_2D ? nb3 : nb2;
+    int ofs1 = is_2D ? nb2 : nb1;
+
+    GGML_ASSERT(nb0 == sizeof(float));
+
+    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+    {
+        float * const wdata = (float *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t iic = ith; iic < IC; iic += nth) {
+                for (int64_t iih = 0; iih < IH; iih++) {
+                    for (int64_t iiw = 0; iiw < IW; iiw++) {
+
+                        // micro kernel
+                        float grad = 0.0f;
+                        for (int64_t ikh = 0; ikh < KH; ikh++) {
+                            for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                // For s0 > 1 some values were skipped over in the forward pass.
+                                // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
+                                const int64_t tmpw = (iiw + p0 - ikw*d0);
+                                if (tmpw % s0 != 0) {
+                                    continue;
+                                }
+                                const int64_t iow = tmpw / s0;
+
+                                // Equivalent logic as above except for s1.
+                                int64_t ioh;
+                                if (is_2D) {
+                                    const int64_t tmph = iih + p1 - ikh*d1;
+
+                                    if (tmph % s1 != 0) {
+                                        continue;
+                                    }
+
+                                    ioh = tmph / s1;
+                                } else {
+                                    ioh = 0;
+                                }
+
+                                if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
+                                    continue;
+                                }
+
+                                const float * const grad_in = (const float *) src0->data
+                                    + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+                                grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
+                            }
+                        }
+                        float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
+                        dst_data[iih*IW + iiw] = grad;
+                    }
+                }
+            }
+        }
+    }
+}
+
+
+// ggml_compute_forward_im2col_3d_f16
+// src0: kernel [OC*IC, KD, KH, KW]
+// src1: image  [N*IC, ID, IH, IW]
+// dst:  result [N*OD, OH, OW, IC * KD * KH * KW]
+static void ggml_compute_forward_im2col_3d_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+    const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t N  = ne13 / IC;
+    const int64_t ID = ne12;
+    const int64_t IH = ne11;
+    const int64_t IW = ne10;
+
+    const int64_t OC = ne03 / IC;
+    GGML_UNUSED(OC);
+    const int64_t KD = ne02;
+    const int64_t KH = ne01;
+    const int64_t KW = ne00;
+
+    const int64_t OD = ne3 / N;
+    const int64_t OH = ne2;
+    const int64_t OW = ne1;
+    const int64_t OH_OW       = OH*OW;
+    const int64_t KD_KH_KW    = KD*KH*KW;
+    const int64_t KH_KW       = KH*KW;
+    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
+
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+    {
+        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t iod = 0; iod < OD; iod++) {
+                for (int64_t ioh = 0; ioh < OH; ioh++) {
+                    for (int64_t iow = 0; iow < OW; iow++) {
+                        for (int64_t iic = ith; iic < IC; iic += nth) {
+
+                            // micro kernel
+                            ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
+                            const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
+
+                            for (int64_t ikd = 0; ikd < KD; ikd++) {
+                                for (int64_t ikh = 0; ikh < KH; ikh++) {
+                                    for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                        const int64_t iiw = iow*s0 + ikw*d0 - p0;
+                                        const int64_t iih = ioh*s1 + ikh*d1 - p1;
+                                        const int64_t iid = iod*s2 + ikd*d2 - p2;
+
+                                        if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
+                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
+                                        } else {
+                                            const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
+                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s);
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+// ggml_compute_forward_im2col_3d_f32
+// src0: kernel [OC*IC, KD, KH, KW]
+// src1: image  [N*IC, ID, IH, IW]
+// dst:  result [N*OD, OH, OW, IC * KD * KH * KW]
+static void ggml_compute_forward_im2col_3d_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -6468,77 +6528,72 @@ void ggml_compute_forward_im2col_back_f32(
 
     const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
     const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
-    const int32_t p0    = ((const int32_t *)(dst->op_params))[2];
-    const int32_t p1    = ((const int32_t *)(dst->op_params))[3];
-    const int32_t d0    = ((const int32_t *)(dst->op_params))[4];
-    const int32_t d1    = ((const int32_t *)(dst->op_params))[5];
-    const bool    is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+    const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t N  = is_2D ? ne3 : ne2;
-    const int64_t IC = is_2D ? ne2 : ne1;
-    const int64_t IH = is_2D ? ne1 : 1;
-    const int64_t IW = ne0;
+    const int64_t N  = ne13 / IC;
+    const int64_t ID = ne12;
+    const int64_t IH = ne11;
+    const int64_t IW = ne10;
 
-    const int64_t KH = is_2D ? ne11 : 1;
-    const int64_t KW = ne10;
+    const int64_t OC = ne03 / IC;
+    GGML_UNUSED(OC);
+    const int64_t KD = ne02;
+    const int64_t KH = ne01;
+    const int64_t KW = ne00;
 
-    const int64_t OH = is_2D ? ne02 : 1;
-    const int64_t OW = ne01;
+    const int64_t OD = ne3 / N;
+    const int64_t OH = ne2;
+    const int64_t OW = ne1;
 
-    int ofs0 = is_2D ? nb3 : nb2;
-    int ofs1 = is_2D ? nb2 : nb1;
+    const int64_t OH_OW       = OH*OW;
+    const int64_t KD_KH_KW    = KD*KH*KW;
+    const int64_t KH_KW       = KH*KW;
+    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
 
-    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
 
-    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+    // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
     {
         float * const wdata = (float *) dst->data;
 
         for (int64_t in = 0; in < N; in++) {
-            for (int64_t iic = ith; iic < IC; iic += nth) {
-                for (int64_t iih = 0; iih < IH; iih++) {
-                    for (int64_t iiw = 0; iiw < IW; iiw++) {
-
-                        // micro kernel
-                        float grad = 0.0f;
-                        for (int64_t ikh = 0; ikh < KH; ikh++) {
-                            for (int64_t ikw = 0; ikw < KW; ikw++) {
-                                // For s0 > 1 some values were skipped over in the forward pass.
-                                // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
-                                const int64_t tmpw = (iiw + p0 - ikw*d0);
-                                if (tmpw % s0 != 0) {
-                                    continue;
-                                }
-                                const int64_t iow = tmpw / s0;
-
-                                // Equivalent logic as above except for s1.
-                                int64_t ioh;
-                                if (is_2D) {
-                                    const int64_t tmph = iih + p1 - ikh*d1;
-
-                                    if (tmph % s1 != 0) {
-                                        continue;
+            for (int64_t iod = 0; iod < OD; iod++) {
+                for (int64_t ioh = 0; ioh < OH; ioh++) {
+                    for (int64_t iow = 0; iow < OW; iow++) {
+                        for (int64_t iic = ith; iic < IC; iic += nth) {
+
+                            // micro kernel
+                            float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
+                            const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
+
+                            for (int64_t ikd = 0; ikd < KD; ikd++) {
+                                for (int64_t ikh = 0; ikh < KH; ikh++) {
+                                    for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                        const int64_t iiw = iow*s0 + ikw*d0 - p0;
+                                        const int64_t iih = ioh*s1 + ikh*d1 - p1;
+                                        const int64_t iid = iod*s2 + ikd*d2 - p2;
+
+                                        if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
+                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
+                                        } else {
+                                            const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
+                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s;
+                                        }
                                     }
-
-                                    ioh = tmph / s1;
-                                } else {
-                                    ioh = 0;
-                                }
-
-                                if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
-                                    continue;
                                 }
-
-                                const float * const grad_in = (const float *) src0->data
-                                    + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                                grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
                             }
                         }
-                        float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
-                        dst_data[iih*IW + iiw] = grad;
                     }
                 }
             }
@@ -6546,6 +6601,26 @@ void ggml_compute_forward_im2col_back_f32(
     }
 }
 
+
+void ggml_compute_forward_im2col_3d(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    switch (dst->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_im2col_3d_f16(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_im2col_3d_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
                               void * a, void * b, float * c) {
     const ggml_type_traits * traits = ggml_get_type_traits(type);
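Aside: im2col_3d unrolls each output voxel into a row of IC*KD*KH*KW input taps (zero for out-of-bounds positions), so the 3-D convolution becomes a single matrix multiply against the flattened kernel. The index arithmetic for one voxel, in isolation (a sketch; contiguous layout assumed for brevity):

static void gather_patch_3d(float * out, const float * img,
                            int ID, int IH, int IW, int IC,
                            int KD, int KH, int KW,
                            int od, int oh, int ow,
                            int s0, int s1, int s2,
                            int p0, int p1, int p2,
                            int d0, int d1, int d2) {
    for (int ic = 0; ic < IC; ++ic) {
        for (int kd = 0; kd < KD; ++kd) {
            for (int kh = 0; kh < KH; ++kh) {
                for (int kw = 0; kw < KW; ++kw) {
                    const int id = od*s2 + kd*d2 - p2; // depth tap
                    const int ih = oh*s1 + kh*d1 - p1; // height tap
                    const int iw = ow*s0 + kw*d0 - p0; // width tap
                    const int o  = ((ic*KD + kd)*KH + kh)*KW + kw;
                    const int ok = id >= 0 && id < ID && ih >= 0 && ih < IH && iw >= 0 && iw < IW;
                    out[o] = ok ? img[((ic*ID + id)*IH + ih)*IW + iw] : 0.0f;
                }
            }
        }
    }
}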
@@ -6726,6 +6801,148 @@ void ggml_compute_forward_conv_2d(
     ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
 }
 
+// ggml_compute_forward_conv_3d
+
+static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
+                                              const ggml_tensor *         kernel,
+                                              const ggml_tensor *         src,
+                                              ggml_tensor *               dst,
+                                              ggml_type                   kernel_type) {
+
+    GGML_ASSERT(ggml_is_contiguous(kernel));
+    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
+    GGML_ASSERT(kernel->type == kernel_type);
+
+    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
+
+    const int32_t s0 = dst->op_params[0];
+    const int32_t s1 = dst->op_params[1];
+    const int32_t s2 = dst->op_params[2];
+    const int32_t p0 = dst->op_params[3];
+    const int32_t p1 = dst->op_params[4];
+    const int32_t p2 = dst->op_params[5];
+    const int32_t d0 = dst->op_params[6];
+    const int32_t d1 = dst->op_params[7];
+    const int32_t d2 = dst->op_params[8];
+    const int32_t c  = dst->op_params[9];
+    const int32_t n  = dst->op_params[10];
+    const int32_t oc = dst->op_params[11];
+
+    const int64_t src_w = src->ne[0];
+    const int64_t src_h = src->ne[1];
+    const int64_t src_d = src->ne[2];
+    const int64_t knl_w = kernel->ne[0];
+    const int64_t knl_h = kernel->ne[1];
+    const int64_t knl_d = kernel->ne[2];
+    const int64_t dst_w = dst->ne[0];
+    const int64_t dst_h = dst->ne[1];
+    const int64_t dst_d = dst->ne[2];
+
+    const float * src_data = (float *) src->data;
+    void        * knl_data = kernel->data;
+    float       * dst_data = (float *) dst->data;
+
+    const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
+    const int64_t knl_n_total       = knl_n_per_channel * c;
+    const int64_t patch_total       = n * dst_w * dst_h * dst_d;
+
+    const int64_t space_per_patch   = knl_n_total * traits->type_size + oc * sizeof(float);
+    const int64_t batch_size        = params->wsize / space_per_patch;
+    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
+    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
+
+    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
+
+    void * tmp = params->wdata;
+
+    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
+        const int64_t patch_start_batch = batch_i * patches_per_batch;
+        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch, patch_total);
+        const int64_t patch_n_in_batch  = patch_end_batch - patch_start_batch;
+
+        const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
+        const int64_t patch_start      = patch_start_batch + params->ith * patch_per_thread;
+        const int64_t patch_end        = std::min(patch_start + patch_per_thread, patch_end_batch);
+
+        for (int64_t p = patch_start; p < patch_end; ++p) {
+            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
+            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
+            const int64_t dst_y      = p_in_depth / dst_w;
+            const int64_t dst_x      = p_in_depth % dst_w;
+
+            char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
+
+            for (int64_t ic = 0; ic < c; ++ic) {
+                for (int64_t kz = 0; kz < knl_d; ++kz) {
+                    for (int64_t ky = 0; ky < knl_h; ++ky) {
+                        for (int64_t kx = 0; kx < knl_w; ++kx) {
+                            const int64_t sz = dst_z * s2 + kz * d2 - p2;
+                            const int64_t sy = dst_y * s1 + ky * d1 - p1;
+                            const int64_t sx = dst_x * s0 + kx * d0 - p0;
+
+                            int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
+
+                            float src_val;
+                            if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
+                                src_val = 0.0f;
+                            } else {
+                                const int64_t cn_idx  = batch_idx * c + ic;
+                                const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
+                                src_val = *src_ptr;
+                            }
+
+                            char * element_ptr = dst_row + dst_idx * traits->type_size;
+                            if (kernel_type == GGML_TYPE_F32) {
+                                *(float *)element_ptr = src_val;
+                            } else if (kernel_type == GGML_TYPE_F16) {
+                                *(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        ggml_barrier(params->threadpool);
+
+        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
+        ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
+
+        ggml_barrier(params->threadpool);
+
+        const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
+        const int64_t permute_start      = params->ith * permute_per_thread;
+        const int64_t permute_end        = std::min(permute_start + permute_per_thread, patch_n_in_batch);
+
+        for (int64_t i = permute_start; i < permute_end; ++i) {
+            const int64_t p          = patch_start_batch + i;
+            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
+            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
+            const int64_t dst_y      = p_in_depth / dst_w;
+            const int64_t dst_x      = p_in_depth % dst_w;
+
+            for (int64_t ioc = 0; ioc < oc; ++ioc) {
+                const float   value   = gemm_output[i * oc + ioc];
+                const int64_t ocn_idx = batch_idx * oc + ioc;
+                float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
+                *dst_ptr = value;
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_conv_3d(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
+}
+
 // ggml_compute_forward_conv_transpose_2d
 
 void ggml_compute_forward_conv_transpose_2d(
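Aside: conv_3d mirrors the 2-D path: pack a batch of im2col patches into params->wdata, run one ggml_call_mul_mat per batch, then permute the GEMM output into the destination layout. The batch size is chosen so that the patches plus the GEMM output fit in the scratch buffer; a sketch of that sizing arithmetic:

#include <stddef.h>

// bytes per patch: one im2col row plus one GEMM output row
static size_t space_per_patch(size_t knl_n_total, size_t type_size, size_t oc) {
    return knl_n_total*type_size + oc*sizeof(float);
}

// patches per batch, rounded down to a multiple of 8 when possible
static size_t patches_per_batch(size_t wsize, size_t per_patch) {
    const size_t n = wsize / per_patch;
    return n > 8 ? (n/8)*8 : n;
}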
@@ -7391,6 +7608,15 @@ static void ggml_compute_forward_pad_f32(
     GGML_TENSOR_UNARY_OP_LOCALS
 
     float * dst_ptr = (float *) dst->data;
+    const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
+    const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
+    const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
+    const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
+    const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
+    const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
+    const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
+    const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
+
 
     // TODO: optimize
 
@@ -7399,10 +7625,12 @@ static void ggml_compute_forward_pad_f32(
             for (int64_t i0 = 0; i0 < ne0; ++i0) {
                 for (int64_t i3 = 0; i3 < ne3; ++i3) {
                     const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
-
-                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-
-                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                    if ((i0 >= lp0 && i0 < ne0 - rp0) \
+                        && (i1 >= lp1 && i1 < ne1 - rp1) \
+                        && (i2 >= lp2 && i2 < ne2 - rp2) \
+                        && (i3 >= lp3 && i3 < ne3 - rp3)) {
+                        const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
+                        const float * src_ptr = (const float *)((char *) src0->data + src_idx);
                         dst_ptr[dst_idx] = *src_ptr;
                     } else {
                         dst_ptr[dst_idx] = 0;
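Aside: pad_f32 now supports asymmetric padding; output index i maps to input index i - lp, and anything outside [lp, ne - rp) becomes zero. The 1-D version of the mapping (a sketch):

// pad a row of n_in floats with lp zeros on the left and rp on the right
static void pad_1d_ref(float * dst, const float * src, int n_in, int lp, int rp) {
    const int n_out = lp + n_in + rp;
    for (int i = 0; i < n_out; ++i) {
        dst[i] = (i >= lp && i < n_out - rp) ? src[i - lp] : 0.0f;
    }
}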
@@ -7601,7 +7829,7 @@ static void ggml_compute_forward_timestep_embedding_f32(
             embed_data[j + half] = sinf(arg);
         }
         if (dim % 2 != 0 && ith == 0) {
-            embed_data[dim] = 0.f;
+            embed_data[2 * half] = 0.f;
         }
     }
 }
@@ -7687,12 +7915,14 @@ void ggml_compute_forward_argsort(
 
 static void ggml_compute_forward_flash_attn_ext_f16(
         const ggml_compute_params * params,
-        const ggml_tensor * q,
-        const ggml_tensor * k,
-        const ggml_tensor * v,
-        const ggml_tensor * mask,
         ggml_tensor * dst) {
 
+    const ggml_tensor * q     = dst->src[0];
+    const ggml_tensor * k     = dst->src[1];
+    const ggml_tensor * v     = dst->src[2];
+    const ggml_tensor * mask  = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
+
     GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
     GGML_TENSOR_LOCALS(size_t,  nbq, q, nb)
     GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
@@ -7766,7 +7996,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    ggml_type         const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
+    ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
     ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
     ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
     ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
@@ -7798,7 +8028,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             memset(VKQ32, 0, DV*sizeof(float));
         }
 
-        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
 
         // k indices
         const int ik3 = iq3 / rk3;
@@ -7887,6 +8117,23 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             }
         }
 
+        // sinks
+        if (sinks) {
+            const float s = ((float *)((char *) sinks->data))[h];
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (s > M) {
+                ms = expf(M - s);
+                ggml_vec_scale_f32(DV, VKQ32, ms);
+            } else {
+                vs = expf(s - M);
+            }
+
+            S = S*ms + vs;
+        }
+
         // V /= S
         const float S_inv = 1.0f/S;
         ggml_vec_scale_f32(DV, VKQ32, S_inv);
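Aside: in the flash-attention accumulator the sink is merged with the usual online-softmax rescaling: if the sink logit s exceeds the running max M, the value accumulator is rescaled by exp(M - s) and the sink contributes 1 to the denominator; otherwise it only adds exp(s - M). A scalar sketch mirroring the hunk, not the vendored kernel:

#include <math.h>

// fold one sink logit into the running state (M = max, S = denominator, V[0..DV) = values)
static void fold_in_sink(float s, float * M, float * S, float * V, int DV) {
    if (s > *M) {
        const float ms = expf(*M - s); // rescale what was accumulated so far
        for (int i = 0; i < DV; ++i) {
            V[i] *= ms;
        }
        *S = *S*ms + 1.0f; // exp(s - s) = 1 for the sink itself
        *M = s;
    } else {
        *S += expf(s - *M); // sink only inflates the denominator
    }
}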
@@ -7906,17 +8153,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
 void ggml_compute_forward_flash_attn_ext(
         const ggml_compute_params * params,
-        const ggml_tensor * q,
-        const ggml_tensor * k,
-        const ggml_tensor * v,
-        const ggml_tensor * mask,
         ggml_tensor * dst) {
     switch (dst->op_params[3]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {
                 // uses F32 accumulators
-                ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+                ggml_compute_forward_flash_attn_ext_f16(params, dst);
             } break;
         default:
             {
@@ -8336,120 +8579,214 @@ void ggml_compute_forward_ssm_conv(
 static void ggml_compute_forward_ssm_scan_f32(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0]; // s
-    const ggml_tensor * src1 = dst->src[1]; // x
-    const ggml_tensor * src2 = dst->src[2]; // dt
-    const ggml_tensor * src3 = dst->src[3]; // A
-    const ggml_tensor * src4 = dst->src[4]; // B
-    const ggml_tensor * src5 = dst->src[5]; // C
+    const ggml_tensor * src0 = dst->src[0]; // s   {d_state, dim, n_head, n_seqs+}
+    const ggml_tensor * src1 = dst->src[1]; // x   {dim, n_head, n_seq_tokens, n_seqs}
+    const ggml_tensor * src2 = dst->src[2]; // dt  {n_head, n_seq_tokens, n_seqs}
+    const ggml_tensor * src3 = dst->src[3]; // A   {d_state, n_head} or {1, n_head}
+    const ggml_tensor * src4 = dst->src[4]; // B   {d_state, n_group, n_seq_tokens, n_seqs}
+    const ggml_tensor * src5 = dst->src[5]; // C   {d_state, n_group, n_seq_tokens, n_seqs}
+    const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t nc
-    const int64_t nr
-    const int64_t
-    const int64_t
+    const int64_t nc = src0->ne[0]; // d_state
+    const int64_t nr = src0->ne[1]; // dim
+    const int64_t nh = src1->ne[1]; // n_head
+    const int64_t ng = src4->ne[1];
+    const int64_t nt = src1->ne[2]; // number of tokens per sequence
+    const int64_t ns = src1->ne[3]; // number of sequences in the batch
+
+    // can't use ggml_nbytes because src1 is not necessarily contiguous
+    const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
 
-    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
     GGML_ASSERT(src3->nb[0] == sizeof(float));
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
-
-    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-    // required for per-sequence offsets for states
-    GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
-    // required to get correct offset for state destination (i.e. src1->nb[3])
-    GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
+    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
+    GGML_ASSERT(nh % ng == 0);
 
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    // heads per thread
+    const int dh = (nh + nth - 1)/nth;
 
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
+    // head range for this thread
+    const int ih0 = dh*ith;
+    const int ih1 = MIN(ih0 + dh, nh);
+
+    const int32_t * ids = (const int32_t *) src6->data;
+
+    for (int i3 = 0; i3 < ns; ++i3) {
+        const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
+        float *       s  = (      float *) ((      char *)  dst->data +     i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
+
+        for (int i2 = 0; i2 < nt; ++i2) {
+            const float * x  = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
+            const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
+            const float * A  = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
+            float *       y  = (      float *) ((      char *)  dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
+
+            if (src3->ne[0] == 1) {
+                // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
+
+                // n_head
+                for (int h = ih0; h < ih1; ++h) {
+                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+                    const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+                    const float dA = expf(dt_soft_plus * A[h]);
+                    const int g = h / (nh / ng); // repeat_interleave
+
+                    // dim
+                    for (int i1 = 0; i1 < nr; ++i1) {
+                        const int ii = i1 + h*nr;
+                        const float x_dt = x[ii] * dt_soft_plus;
+                        float sumf = 0.0f;
+#if defined(GGML_SIMD)
+    #if defined(__ARM_FEATURE_SVE)
+                        const int ggml_f32_epr  = svcntw();
+                        const int ggml_f32_step = 1 * ggml_f32_epr;
+
+                        const int np = (nc & ~(ggml_f32_step - 1));
+
+                        GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
+
+                        GGML_F32_VEC adA  = GGML_F32_VEC_SET1(dA);
+                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+                        for (int i = 0; i < np; i += ggml_f32_step) {
+                            // TODO: maybe unroll more?
+                            for (int j = 0; j < 1; j++) {
+                                GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
+                                GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B  + i + j*ggml_f32_epr + g*nc);
+                                GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C  + i + j*ggml_f32_epr + g*nc);
+
+                                t0 = GGML_F32_VEC_MUL(t0, adA);
+                                t1 = GGML_F32_VEC_MUL(t1, axdt);
+
+                                t0 = GGML_F32_VEC_ADD(t0, t1);
+
+                                sum = GGML_F32_VEC_FMA(sum, t0, t2);
+
+                                GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
+                            }
+                        }
+
+                        sumf = GGML_F32xt_REDUCE_ONE(sum);
+    #elif defined(__riscv_v_intrinsic)
+                        // todo: RVV implementation
+                        const int np = 0;
+    #else
+                        const int np = (nc & ~(GGML_F32_STEP - 1));
+
+                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+
+                        GGML_F32_VEC adA  = GGML_F32_VEC_SET1(dA);
+                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+                        GGML_F32_VEC ax[GGML_F32_ARR];
+                        GGML_F32_VEC ay[GGML_F32_ARR];
+                        GGML_F32_VEC az[GGML_F32_ARR];
+
+                        for (int i = 0; i < np; i += GGML_F32_STEP) {
+                            for (int j = 0; j < GGML_F32_ARR; j++) {
+                                ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
+                                ay[j] = GGML_F32_VEC_LOAD(B  + i + j*GGML_F32_EPR + g*nc);
+                                az[j] = GGML_F32_VEC_LOAD(C  + i + j*GGML_F32_EPR + g*nc);
+
+                                ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
+                                ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
+
+                                ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
+
+                                sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
+
+                                GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
+                            }
+                        }
 
-    [... removed lines not captured in this diff rendering ...]
-        float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
-        float x_dt = x[i1] * dt_soft_plus;
-        svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
-        svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
-        svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
-
-        for (int64_t k = 0; k < nc; k += svcntw()) {
-            svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
-            svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]);
-            svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]);
-            svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
-
-            svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
-            t1 = exp_ps_sve(svptrue_b32(), t1);
-            svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
-
-            vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
-            r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
-
-            GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
+                        // reduce sum0..sum3 to sum0
+                        GGML_F32_VEC_REDUCE(sumf, sum);
+    #endif
+#else
+                        const int np = 0;
+#endif
+                        // d_state
+                        for (int i0 = np; i0 < nc; ++i0) {
+                            const int i  = i0 + ii*nc;
+                            const int ig = i0 + g*nc;
+                            // state = prev_state * dA + dB * x
+                            const float state = (s0[i] * dA) + (B[ig] * x_dt);
+                            // y = rowwise_dotprod(state, C)
+                            sumf += state * C[ig];
+                            s[i] = state;
+                        }
+                        y[ii] = sumf;
                     }
-            y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
                 }
-        }
-
-    [... removed lines not captured in this diff rendering ...]
+            } else {
+                // Mamba-1 has an element-wise decay factor for the states
+
+                // n_head
+                for (int h = ih0; h < ih1; ++h) {
+                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+                    const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+                    const int g = h / (nh / ng); // repeat_interleave
+
+                    // dim
+                    for (int i1 = 0; i1 < nr; ++i1) {
+                        const int ii = i1 + h*nr;
+                        const float x_dt = x[ii] * dt_soft_plus;
+#if defined(__ARM_FEATURE_SVE)
+                        svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
+                        svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
+                        svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
+
+                        // d_state
+                        // TODO: what happens when (d_state % svcntw()) != 0?
+                        for (int64_t k = 0; k < nc; k += svcntw()) {
+                            svfloat32_t vA  = GGML_F32_VEC_LOAD(&A[h*nc + k]);
+                            svfloat32_t vB  = GGML_F32_VEC_LOAD(&B[k + g*nc]);
+                            svfloat32_t vC  = GGML_F32_VEC_LOAD(&C[k + g*nc]);
+                            svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
+
+                            svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+                            t1 = exp_ps_sve(svptrue_b32(), t1);
+                            svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
+
+                            vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
+                            r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+                            GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
+                        }
+                        y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
+#else
+                        float sumf = 0.0f;
+                        // NOTE: can't really use GGML_SIMD here because d_state is usually 16
+                        //       and also because expf is used within the loop.
+                        // d_state
+                        for (int i0 = 0; i0 < nc; ++i0) {
+                            const int i  = i0 + ii*nc;
+                            const int ig = i0 + g*nc;
+                            // state = prev_state * dA + dB * x
+                            const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
+                            // y = rowwise_dotprod(state, C)
+                            sumf += state * C[ig];
+                            s[i] = state;
+                        }
+                        y[ii] = sumf;
+#endif
                     }
-            y[i1] = sumf;
                 }
             }
+            // use the output as the source when it's not the first token-wise iteration
+            s0 = s;
         }
-
+    }
 }
 
 void ggml_compute_forward_ssm_scan(
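Aside: for the Mamba-2 branch (one A value per head) the decay dA = exp(softplus(dt)*A) is constant across the state dimension, which is why it is hoisted out of the inner loop. One (head, dim) row of the scan as a scalar reference (a sketch, not the vendored kernel):

#include <math.h>

static float ssm_scan_row_ref(int d_state, float * s, const float * B, const float * C,
                              float A, float dt, float x) {
    const float dt_sp = dt <= 20.0f ? log1pf(expf(dt)) : dt; // softplus, overflow-guarded
    const float dA    = expf(dt_sp*A);
    const float x_dt  = x*dt_sp;
    float y = 0.0f;
    for (int i = 0; i < d_state; ++i) {
        s[i] = s[i]*dA + B[i]*x_dt; // state = prev_state * dA + dB * x
        y   += s[i]*C[i];           // y = rowwise_dotprod(state, C)
    }
    return y;
}

The ids tensor adds one level of indirection: each sequence reads its previous state from src0 at ids[i3], then chains through dst (s0 = s) for subsequent tokens.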
@@ -8688,6 +9025,18 @@ void ggml_compute_forward_glu(
             {
                 ggml_compute_forward_swiglu(params, dst);
             } break;
+        case GGML_GLU_OP_SWIGLU_OAI:
+            {
+                ggml_compute_forward_swiglu_oai(params, dst);
+            } break;
+        case GGML_GLU_OP_GEGLU_ERF:
+            {
+                ggml_compute_forward_geglu_erf(params, dst);
+            } break;
+        case GGML_GLU_OP_GEGLU_QUICK:
+            {
+                ggml_compute_forward_geglu_quick(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -9283,8 +9632,8 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
     int64_t h_stride_2d = head_size * head_size;
 
 #if defined(GGML_SIMD)
-    #if defined(__ARM_FEATURE_SVE)
-        // scalar Route to scalar implementation       //TODO: Write SVE code
+    #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
+        // scalar Route to scalar implementation       //TODO: Write SVE code and RVV code
     for (int64_t t = 0; t < T; t++) {
         int64_t t_offset = t * t_stride;
         int64_t state_offset = head_size * C * (t / (T / n_seqs));
@@ -9732,6 +10081,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+
     const float alpha  = adamw_params_ptr[0];
     const float beta1  = adamw_params_ptr[1];
     const float beta2  = adamw_params_ptr[2];
@@ -9739,7 +10089,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const float wd     = adamw_params_ptr[4];
     const float beta1h = adamw_params_ptr[5];
     const float beta2h = adamw_params_ptr[6];
-
+    const float keep   = 1.f - alpha * wd;
     for (int ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
         const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -9762,7 +10112,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
             // The weight decay is applied independently of the Adam momenta m and v.
             // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
             // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+            w[i00] = w[i00] * keep - alpha * mh / vh;
         }
     }
 }
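Aside: hoisting keep = 1 - alpha*wd makes the decoupled decay of AdamW explicit: the weight is shrunk first, then moved against the Adam step. Per element (a sketch):

// one AdamW update with decoupled weight decay (mh, vh: bias-corrected moments)
static float adamw_step_ref(float w, float mh, float vh, float alpha, float wd) {
    const float keep = 1.0f - alpha*wd; // constant per tensor, so hoisted out of the loop
    return w*keep - alpha*mh/vh;
}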
@@ -9784,3 +10134,63 @@ void ggml_compute_forward_opt_step_adamw(
             }
     }
 }
+
+static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0       = dst->src[0];
+    const ggml_tensor * src0_grad  = dst->src[1];
+    const ggml_tensor * sgd_params = dst->src[2];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(sgd_params) == 2);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1) / nth;
+
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // using adamw param subset we care about - alpha, wd - could have a separate struct
+    const float * sgd_params_ptr = ggml_get_data_f32(sgd_params);
+    const float   alpha          = sgd_params_ptr[0];
+    const float   keep           = 1.f - alpha * sgd_params_ptr[1];
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir / (ne02 * ne01);
+        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
+        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
+
+        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
+
+        float       * w = (float       *) ((char       *) src0->data      + offset); // weight
+        const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            w[i00] = w[i00] * keep - alpha * g[i00];
+        }
+    }
+}
+
+void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_sgd_f32(params, dst);
+            }
+            break;
+        default:
+            {
+                GGML_ABORT("fatal error - sgd is F32 only");
+            }
+    }
+}
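Aside: the new SGD step is the same decoupled-decay shape minus the moments: w <- w*(1 - alpha*wd) - alpha*g, with alpha and wd read from the 2-element params tensor. Row-level sketch:

// SGD with decoupled weight decay over one row of weights
static void sgd_row_ref(int n, float * w, const float * g, float alpha, float wd) {
    const float keep = 1.0f - alpha*wd;
    for (int i = 0; i < n; ++i) {
        w[i] = w[i]*keep - alpha*g[i];
    }
}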