@novastera-oss/llamarn 0.2.7 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +8 -3
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +56 -22
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +1 -2
- package/cpp/llama.cpp/README.md +4 -5
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +24 -0
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +3 -0
- package/cpp/llama.cpp/common/common.h +5 -0
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
- package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
- package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
- package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
- package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
- package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
- package/cpp/llama.cpp/include/llama.h +8 -43
- package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
- package/cpp/llama.cpp/src/llama-arch.h +36 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
- package/cpp/llama.cpp/src/llama-batch.h +105 -70
- package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +101 -107
- package/cpp/llama.cpp/src/llama-context.h +13 -13
- package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
- package/cpp/llama.cpp/src/llama-graph.h +78 -79
- package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
- package/cpp/llama.cpp/src/llama-hparams.h +11 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
- package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +21 -22
- package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
- package/cpp/llama.cpp/src/llama-model.h +40 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
- package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
- package/cpp/llama.cpp/src/llama-vocab.h +42 -0
- package/cpp/rn-utils.h +3 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +5 -0
- package/ios/include/llama.h +8 -43
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
#include "ggml-cpu.h"
|
|
4
4
|
#include "ggml-impl.h"
|
|
5
5
|
#include "binary-ops.h"
|
|
6
|
+
#include "ggml.h"
|
|
6
7
|
#include "unary-ops.h"
|
|
7
8
|
#include "vec.h"
|
|
8
9
|
|
|
@@ -108,7 +109,7 @@ static void ggml_compute_forward_dup_f16(
|
|
|
108
109
|
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
109
110
|
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
110
111
|
for (int i00 = 0; i00 < ne00; i00++) {
|
|
111
|
-
dst_ptr[id] =
|
|
112
|
+
dst_ptr[id] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
|
|
112
113
|
id++;
|
|
113
114
|
}
|
|
114
115
|
}
|
|
@@ -130,7 +131,7 @@ static void ggml_compute_forward_dup_f16(
|
|
|
130
131
|
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
131
132
|
|
|
132
133
|
for (int i00 = 0; i00 < ne00; i00++) {
|
|
133
|
-
src0_f32[i00] =
|
|
134
|
+
src0_f32[i00] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
|
|
134
135
|
}
|
|
135
136
|
|
|
136
137
|
quantize_row_q(src0_f32, dst_ptr + id, ne00);
|
|
@@ -156,7 +157,7 @@ static void ggml_compute_forward_dup_f16(
|
|
|
156
157
|
for (int i00 = 0; i00 < ne00; i00++) {
|
|
157
158
|
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
158
159
|
|
|
159
|
-
dst_ptr[id] =
|
|
160
|
+
dst_ptr[id] = GGML_CPU_FP16_TO_FP32(*src0_ptr);
|
|
160
161
|
id++;
|
|
161
162
|
}
|
|
162
163
|
}
|
|
@@ -267,7 +268,7 @@ static void ggml_compute_forward_dup_f16(
|
|
|
267
268
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
268
269
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
269
270
|
|
|
270
|
-
*(float *) dst_ptr =
|
|
271
|
+
*(float *) dst_ptr = GGML_CPU_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
|
|
271
272
|
|
|
272
273
|
if (++i10 == ne0) {
|
|
273
274
|
i10 = 0;
|
|
@@ -372,7 +373,7 @@ static void ggml_compute_forward_dup_bf16(
|
|
|
372
373
|
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
373
374
|
const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
374
375
|
for (int i00 = 0; i00 < ne00; i00++) {
|
|
375
|
-
dst_ptr[id] =
|
|
376
|
+
dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
|
|
376
377
|
id++;
|
|
377
378
|
}
|
|
378
379
|
}
|
|
@@ -473,7 +474,7 @@ static void ggml_compute_forward_dup_bf16(
|
|
|
473
474
|
for (int i00 = 0; i00 < ne00; i00++) {
|
|
474
475
|
const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
475
476
|
|
|
476
|
-
dst_ptr[id] =
|
|
477
|
+
dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
|
|
477
478
|
id++;
|
|
478
479
|
}
|
|
479
480
|
}
|
|
@@ -566,7 +567,7 @@ static void ggml_compute_forward_dup_bf16(
|
|
|
566
567
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
567
568
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
568
569
|
|
|
569
|
-
*(ggml_fp16_t *) dst_ptr =
|
|
570
|
+
*(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
|
|
570
571
|
|
|
571
572
|
if (++i10 == ne0) {
|
|
572
573
|
i10 = 0;
|
|
@@ -696,24 +697,8 @@ static void ggml_compute_forward_dup_f32(
|
|
|
696
697
|
if (ggml_is_contiguous(dst)) {
|
|
697
698
|
// TODO: simplify
|
|
698
699
|
if (nb00 == sizeof(float)) {
|
|
699
|
-
if (dst->type
|
|
700
|
-
|
|
701
|
-
const size_t rs = ne00 * nb00;
|
|
702
|
-
char * dst_ptr = (char *) dst->data;
|
|
703
|
-
|
|
704
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
705
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
706
|
-
id += rs * ir0;
|
|
707
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
708
|
-
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
|
709
|
-
memcpy(dst_ptr + id, src0_ptr, rs);
|
|
710
|
-
id += rs;
|
|
711
|
-
}
|
|
712
|
-
id += rs * (ne01 - ir1);
|
|
713
|
-
}
|
|
714
|
-
}
|
|
715
|
-
} else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
|
|
716
|
-
ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
|
|
700
|
+
if (ggml_get_type_traits_cpu(dst->type)->from_float) {
|
|
701
|
+
ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
|
|
717
702
|
|
|
718
703
|
size_t id = 0;
|
|
719
704
|
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
|
|
@@ -724,7 +709,7 @@ static void ggml_compute_forward_dup_f32(
|
|
|
724
709
|
id += rs * ir0;
|
|
725
710
|
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
726
711
|
const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
727
|
-
|
|
712
|
+
from_float(src0_ptr, dst_ptr + id, ne00);
|
|
728
713
|
id += rs;
|
|
729
714
|
}
|
|
730
715
|
id += rs * (ne01 - ir1);
|
|
@@ -765,7 +750,7 @@ static void ggml_compute_forward_dup_f32(
|
|
|
765
750
|
for (int i00 = 0; i00 < ne00; i00++) {
|
|
766
751
|
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
767
752
|
|
|
768
|
-
dst_ptr[id] =
|
|
753
|
+
dst_ptr[id] = GGML_CPU_FP32_TO_FP16(*src0_ptr);
|
|
769
754
|
id++;
|
|
770
755
|
}
|
|
771
756
|
}
|
|
@@ -878,7 +863,7 @@ static void ggml_compute_forward_dup_f32(
|
|
|
878
863
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
879
864
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
880
865
|
|
|
881
|
-
*(ggml_fp16_t *) dst_ptr =
|
|
866
|
+
*(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(*(const float *) src0_ptr);
|
|
882
867
|
|
|
883
868
|
if (++i10 == ne0) {
|
|
884
869
|
i10 = 0;
|
|
@@ -1419,7 +1404,7 @@ static void ggml_compute_forward_add1_f16_f32(
|
|
|
1419
1404
|
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
|
|
1420
1405
|
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
|
|
1421
1406
|
for (int i = 0; i < ne0; i++) {
|
|
1422
|
-
dst_ptr[i] =
|
|
1407
|
+
dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
|
|
1423
1408
|
}
|
|
1424
1409
|
}
|
|
1425
1410
|
}
|
|
@@ -1435,7 +1420,7 @@ static void ggml_compute_forward_add1_f16_f16(
|
|
|
1435
1420
|
GGML_ASSERT(ggml_is_scalar(src1));
|
|
1436
1421
|
|
|
1437
1422
|
// scalar to add
|
|
1438
|
-
const float v =
|
|
1423
|
+
const float v = GGML_CPU_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
|
|
1439
1424
|
|
|
1440
1425
|
const int ith = params->ith;
|
|
1441
1426
|
const int nth = params->nth;
|
|
@@ -1467,7 +1452,7 @@ static void ggml_compute_forward_add1_f16_f16(
|
|
|
1467
1452
|
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
|
|
1468
1453
|
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
|
|
1469
1454
|
for (int i = 0; i < ne0; i++) {
|
|
1470
|
-
dst_ptr[i] =
|
|
1455
|
+
dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
|
|
1471
1456
|
}
|
|
1472
1457
|
}
|
|
1473
1458
|
}
|
|
@@ -1889,7 +1874,7 @@ static void ggml_compute_forward_sum_f16(
|
|
|
1889
1874
|
}
|
|
1890
1875
|
}
|
|
1891
1876
|
}
|
|
1892
|
-
((ggml_fp16_t *) dst->data)[0] =
|
|
1877
|
+
((ggml_fp16_t *) dst->data)[0] = GGML_CPU_FP32_TO_FP16(sum);
|
|
1893
1878
|
}
|
|
1894
1879
|
|
|
1895
1880
|
static void ggml_compute_forward_sum_bf16(
|
|
@@ -2300,6 +2285,12 @@ void ggml_compute_forward_repeat(
|
|
|
2300
2285
|
{
|
|
2301
2286
|
ggml_compute_forward_repeat_f32(params, dst);
|
|
2302
2287
|
} break;
|
|
2288
|
+
// TODO: templateify the implemenation and support for I64
|
|
2289
|
+
// ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
|
|
2290
|
+
//case GGML_TYPE_I64:
|
|
2291
|
+
// {
|
|
2292
|
+
// ggml_compute_forward_repeat_i64(params, dst);
|
|
2293
|
+
// } break;
|
|
2303
2294
|
default:
|
|
2304
2295
|
{
|
|
2305
2296
|
GGML_ABORT("fatal error");
|
|
@@ -2660,7 +2651,7 @@ static void ggml_compute_forward_gelu_f16(
|
|
|
2660
2651
|
#ifndef NDEBUG
|
|
2661
2652
|
for (int k = 0; k < nc; k++) {
|
|
2662
2653
|
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
2663
|
-
const float v =
|
|
2654
|
+
const float v = GGML_CPU_FP16_TO_FP32(x);
|
|
2664
2655
|
GGML_UNUSED(v);
|
|
2665
2656
|
assert(!isnan(v));
|
|
2666
2657
|
assert(!isinf(v));
|
|
@@ -2763,7 +2754,7 @@ static void ggml_compute_forward_gelu_erf_f16(
|
|
|
2763
2754
|
#ifndef NDEBUG
|
|
2764
2755
|
for (int k = 0; k < nc; k++) {
|
|
2765
2756
|
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
2766
|
-
const float v =
|
|
2757
|
+
const float v = GGML_CPU_FP16_TO_FP32(x);
|
|
2767
2758
|
GGML_UNUSED(v);
|
|
2768
2759
|
assert(!isnan(v));
|
|
2769
2760
|
assert(!isinf(v));
|
|
@@ -2866,7 +2857,7 @@ static void ggml_compute_forward_gelu_quick_f16(
|
|
|
2866
2857
|
#ifndef NDEBUG
|
|
2867
2858
|
for (int k = 0; k < nc; k++) {
|
|
2868
2859
|
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
2869
|
-
const float v =
|
|
2860
|
+
const float v = GGML_CPU_FP16_TO_FP32(x);
|
|
2870
2861
|
GGML_UNUSED(v);
|
|
2871
2862
|
assert(!isnan(v));
|
|
2872
2863
|
assert(!isinf(v));
|
|
@@ -2969,7 +2960,7 @@ static void ggml_compute_forward_silu_f16(
|
|
|
2969
2960
|
#ifndef NDEBUG
|
|
2970
2961
|
for (int k = 0; k < nc; k++) {
|
|
2971
2962
|
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
|
|
2972
|
-
const float v =
|
|
2963
|
+
const float v = GGML_CPU_FP16_TO_FP32(x);
|
|
2973
2964
|
GGML_UNUSED(v);
|
|
2974
2965
|
assert(!isnan(v));
|
|
2975
2966
|
assert(!isinf(v));
|
|
@@ -3144,8 +3135,718 @@ static void ggml_compute_forward_silu_back_f16(
|
|
|
3144
3135
|
const int ith = params->ith;
|
|
3145
3136
|
const int nth = params->nth;
|
|
3146
3137
|
|
|
3147
|
-
const int nc = src1->ne[0];
|
|
3148
|
-
const int nr = ggml_nrows(src1);
|
|
3138
|
+
const int nc = src1->ne[0];
|
|
3139
|
+
const int nr = ggml_nrows(src1);
|
|
3140
|
+
|
|
3141
|
+
// rows per thread
|
|
3142
|
+
const int dr = (nr + nth - 1)/nth;
|
|
3143
|
+
|
|
3144
|
+
// row range for this thread
|
|
3145
|
+
const int ir0 = dr*ith;
|
|
3146
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
3147
|
+
|
|
3148
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
3149
|
+
ggml_vec_silu_backward_f16(nc,
|
|
3150
|
+
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
|
|
3151
|
+
(ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
|
|
3152
|
+
(ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
|
|
3153
|
+
|
|
3154
|
+
#ifndef NDEBUG
|
|
3155
|
+
for (int k = 0; k < nc; k++) {
|
|
3156
|
+
const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
3157
|
+
const float v = GGML_CPU_FP16_TO_FP32(x);
|
|
3158
|
+
GGML_UNUSED(v);
|
|
3159
|
+
assert(!isnan(v));
|
|
3160
|
+
assert(!isinf(v));
|
|
3161
|
+
}
|
|
3162
|
+
#endif
|
|
3163
|
+
}
|
|
3164
|
+
}
|
|
3165
|
+
|
|
3166
|
+
void ggml_compute_forward_silu_back(
|
|
3167
|
+
const ggml_compute_params * params,
|
|
3168
|
+
ggml_tensor * dst) {
|
|
3169
|
+
|
|
3170
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3171
|
+
|
|
3172
|
+
switch (src0->type) {
|
|
3173
|
+
case GGML_TYPE_F32:
|
|
3174
|
+
{
|
|
3175
|
+
ggml_compute_forward_silu_back_f32(params, dst);
|
|
3176
|
+
} break;
|
|
3177
|
+
case GGML_TYPE_F16:
|
|
3178
|
+
{
|
|
3179
|
+
ggml_compute_forward_silu_back_f16(params, dst);
|
|
3180
|
+
} break;
|
|
3181
|
+
default:
|
|
3182
|
+
{
|
|
3183
|
+
GGML_ABORT("fatal error");
|
|
3184
|
+
}
|
|
3185
|
+
}
|
|
3186
|
+
}
|
|
3187
|
+
|
|
3188
|
+
// ggml_compute_forward_reglu
|
|
3189
|
+
|
|
3190
|
+
static void ggml_compute_forward_reglu_f32(
|
|
3191
|
+
const ggml_compute_params * params,
|
|
3192
|
+
ggml_tensor * dst) {
|
|
3193
|
+
|
|
3194
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3195
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
3196
|
+
char * src0_d = (char *) src0->data;
|
|
3197
|
+
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
|
3198
|
+
const size_t src0_o = src0->nb[1];
|
|
3199
|
+
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
|
3200
|
+
|
|
3201
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
3202
|
+
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
|
3203
|
+
|
|
3204
|
+
if (src1) {
|
|
3205
|
+
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
|
3206
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
3207
|
+
}
|
|
3208
|
+
|
|
3209
|
+
const int ith = params->ith;
|
|
3210
|
+
const int nth = params->nth;
|
|
3211
|
+
|
|
3212
|
+
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
|
3213
|
+
const int nr = ggml_nrows(src0);
|
|
3214
|
+
|
|
3215
|
+
GGML_ASSERT(dst->ne[0] == nc);
|
|
3216
|
+
GGML_ASSERT(ggml_nrows(dst) == nr);
|
|
3217
|
+
|
|
3218
|
+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
|
3219
|
+
|
|
3220
|
+
// rows per thread
|
|
3221
|
+
const int dr = (nr + nth - 1)/nth;
|
|
3222
|
+
|
|
3223
|
+
// row range for this thread
|
|
3224
|
+
const int ir0 = dr*ith;
|
|
3225
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
3226
|
+
|
|
3227
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
3228
|
+
float * src0_p = (float *) (src0_d + i1*src0_o);
|
|
3229
|
+
float * src1_p = (float *) (src1_d + i1*src1_o);
|
|
3230
|
+
|
|
3231
|
+
if (!src1) {
|
|
3232
|
+
src0_p += swapped ? nc : 0;
|
|
3233
|
+
src1_p += swapped ? 0 : nc;
|
|
3234
|
+
}
|
|
3235
|
+
|
|
3236
|
+
ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3237
|
+
|
|
3238
|
+
#ifndef NDEBUG
|
|
3239
|
+
for (int k = 0; k < nc; k++) {
|
|
3240
|
+
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
3241
|
+
GGML_UNUSED(x);
|
|
3242
|
+
assert(!isnan(x));
|
|
3243
|
+
assert(!isinf(x));
|
|
3244
|
+
}
|
|
3245
|
+
#endif
|
|
3246
|
+
}
|
|
3247
|
+
}
|
|
3248
|
+
|
|
3249
|
+
static void ggml_compute_forward_reglu_f16(
|
|
3250
|
+
const ggml_compute_params * params,
|
|
3251
|
+
ggml_tensor * dst) {
|
|
3252
|
+
|
|
3253
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3254
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
3255
|
+
char * src0_d = (char *) src0->data;
|
|
3256
|
+
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
|
3257
|
+
const size_t src0_o = src0->nb[1];
|
|
3258
|
+
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
|
3259
|
+
|
|
3260
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
3261
|
+
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
|
3262
|
+
|
|
3263
|
+
if (src1) {
|
|
3264
|
+
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
|
3265
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
3266
|
+
}
|
|
3267
|
+
|
|
3268
|
+
const int ith = params->ith;
|
|
3269
|
+
const int nth = params->nth;
|
|
3270
|
+
|
|
3271
|
+
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
|
3272
|
+
const int nr = ggml_nrows(src0);
|
|
3273
|
+
|
|
3274
|
+
GGML_ASSERT(dst->ne[0] == nc);
|
|
3275
|
+
GGML_ASSERT(ggml_nrows(dst) == nr);
|
|
3276
|
+
|
|
3277
|
+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
|
3278
|
+
|
|
3279
|
+
// rows per thread
|
|
3280
|
+
const int dr = (nr + nth - 1)/nth;
|
|
3281
|
+
|
|
3282
|
+
// row range for this thread
|
|
3283
|
+
const int ir0 = dr*ith;
|
|
3284
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
3285
|
+
|
|
3286
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
3287
|
+
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
|
|
3288
|
+
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
|
|
3289
|
+
|
|
3290
|
+
if (!src1) {
|
|
3291
|
+
src0_p += swapped ? nc : 0;
|
|
3292
|
+
src1_p += swapped ? 0 : nc;
|
|
3293
|
+
}
|
|
3294
|
+
|
|
3295
|
+
ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3296
|
+
|
|
3297
|
+
#ifndef NDEBUG
|
|
3298
|
+
for (int k = 0; k < nc; k++) {
|
|
3299
|
+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
3300
|
+
const float v = GGML_FP16_TO_FP32(x);
|
|
3301
|
+
GGML_UNUSED(v);
|
|
3302
|
+
assert(!isnan(v));
|
|
3303
|
+
assert(!isinf(v));
|
|
3304
|
+
}
|
|
3305
|
+
#endif
|
|
3306
|
+
}
|
|
3307
|
+
}
|
|
3308
|
+
|
|
3309
|
+
static void ggml_compute_forward_reglu(
|
|
3310
|
+
const ggml_compute_params * params,
|
|
3311
|
+
ggml_tensor * dst) {
|
|
3312
|
+
|
|
3313
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3314
|
+
|
|
3315
|
+
switch (src0->type) {
|
|
3316
|
+
case GGML_TYPE_F32:
|
|
3317
|
+
{
|
|
3318
|
+
ggml_compute_forward_reglu_f32(params, dst);
|
|
3319
|
+
} break;
|
|
3320
|
+
case GGML_TYPE_F16:
|
|
3321
|
+
{
|
|
3322
|
+
ggml_compute_forward_reglu_f16(params, dst);
|
|
3323
|
+
} break;
|
|
3324
|
+
default:
|
|
3325
|
+
{
|
|
3326
|
+
GGML_ABORT("fatal error");
|
|
3327
|
+
}
|
|
3328
|
+
}
|
|
3329
|
+
}
|
|
3330
|
+
|
|
3331
|
+
// ggml_compute_forward_geglu
|
|
3332
|
+
|
|
3333
|
+
static void ggml_compute_forward_geglu_f32(
|
|
3334
|
+
const ggml_compute_params * params,
|
|
3335
|
+
ggml_tensor * dst) {
|
|
3336
|
+
|
|
3337
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3338
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
3339
|
+
char * src0_d = (char *) src0->data;
|
|
3340
|
+
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
|
3341
|
+
const size_t src0_o = src0->nb[1];
|
|
3342
|
+
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
|
3343
|
+
|
|
3344
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
3345
|
+
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
|
3346
|
+
|
|
3347
|
+
if (src1) {
|
|
3348
|
+
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
|
3349
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
3350
|
+
}
|
|
3351
|
+
|
|
3352
|
+
const int ith = params->ith;
|
|
3353
|
+
const int nth = params->nth;
|
|
3354
|
+
|
|
3355
|
+
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
|
3356
|
+
const int nr = ggml_nrows(src0);
|
|
3357
|
+
|
|
3358
|
+
GGML_ASSERT(dst->ne[0] == nc);
|
|
3359
|
+
GGML_ASSERT(ggml_nrows(dst) == nr);
|
|
3360
|
+
|
|
3361
|
+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
|
3362
|
+
|
|
3363
|
+
// rows per thread
|
|
3364
|
+
const int dr = (nr + nth - 1)/nth;
|
|
3365
|
+
|
|
3366
|
+
// row range for this thread
|
|
3367
|
+
const int ir0 = dr*ith;
|
|
3368
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
3369
|
+
|
|
3370
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
3371
|
+
float * src0_p = (float *) (src0_d + i1*src0_o);
|
|
3372
|
+
float * src1_p = (float *) (src1_d + i1*src1_o);
|
|
3373
|
+
|
|
3374
|
+
if (!src1) {
|
|
3375
|
+
src0_p += swapped ? nc : 0;
|
|
3376
|
+
src1_p += swapped ? 0 : nc;
|
|
3377
|
+
}
|
|
3378
|
+
|
|
3379
|
+
ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3380
|
+
|
|
3381
|
+
#ifndef NDEBUG
|
|
3382
|
+
for (int k = 0; k < nc; k++) {
|
|
3383
|
+
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
3384
|
+
GGML_UNUSED(x);
|
|
3385
|
+
assert(!isnan(x));
|
|
3386
|
+
assert(!isinf(x));
|
|
3387
|
+
}
|
|
3388
|
+
#endif
|
|
3389
|
+
}
|
|
3390
|
+
}
|
|
3391
|
+
|
|
3392
|
+
static void ggml_compute_forward_geglu_f16(
|
|
3393
|
+
const ggml_compute_params * params,
|
|
3394
|
+
ggml_tensor * dst) {
|
|
3395
|
+
|
|
3396
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3397
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
3398
|
+
char * src0_d = (char *) src0->data;
|
|
3399
|
+
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
|
3400
|
+
const size_t src0_o = src0->nb[1];
|
|
3401
|
+
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
|
3402
|
+
|
|
3403
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
3404
|
+
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
|
3405
|
+
|
|
3406
|
+
if (src1) {
|
|
3407
|
+
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
|
3408
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
3409
|
+
}
|
|
3410
|
+
|
|
3411
|
+
const int ith = params->ith;
|
|
3412
|
+
const int nth = params->nth;
|
|
3413
|
+
|
|
3414
|
+
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
|
3415
|
+
const int nr = ggml_nrows(src0);
|
|
3416
|
+
|
|
3417
|
+
GGML_ASSERT(dst->ne[0] == nc);
|
|
3418
|
+
GGML_ASSERT(ggml_nrows(dst) == nr);
|
|
3419
|
+
|
|
3420
|
+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
|
3421
|
+
|
|
3422
|
+
// rows per thread
|
|
3423
|
+
const int dr = (nr + nth - 1)/nth;
|
|
3424
|
+
|
|
3425
|
+
// row range for this thread
|
|
3426
|
+
const int ir0 = dr*ith;
|
|
3427
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
3428
|
+
|
|
3429
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
3430
|
+
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
|
|
3431
|
+
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
|
|
3432
|
+
|
|
3433
|
+
if (!src1) {
|
|
3434
|
+
src0_p += swapped ? nc : 0;
|
|
3435
|
+
src1_p += swapped ? 0 : nc;
|
|
3436
|
+
}
|
|
3437
|
+
|
|
3438
|
+
ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3439
|
+
|
|
3440
|
+
#ifndef NDEBUG
|
|
3441
|
+
for (int k = 0; k < nc; k++) {
|
|
3442
|
+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
3443
|
+
const float v = GGML_FP16_TO_FP32(x);
|
|
3444
|
+
GGML_UNUSED(v);
|
|
3445
|
+
assert(!isnan(v));
|
|
3446
|
+
assert(!isinf(v));
|
|
3447
|
+
}
|
|
3448
|
+
#endif
|
|
3449
|
+
}
|
|
3450
|
+
}
|
|
3451
|
+
|
|
3452
|
+
static void ggml_compute_forward_geglu(
|
|
3453
|
+
const ggml_compute_params * params,
|
|
3454
|
+
ggml_tensor * dst) {
|
|
3455
|
+
|
|
3456
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3457
|
+
|
|
3458
|
+
switch (src0->type) {
|
|
3459
|
+
case GGML_TYPE_F32:
|
|
3460
|
+
{
|
|
3461
|
+
ggml_compute_forward_geglu_f32(params, dst);
|
|
3462
|
+
} break;
|
|
3463
|
+
case GGML_TYPE_F16:
|
|
3464
|
+
{
|
|
3465
|
+
ggml_compute_forward_geglu_f16(params, dst);
|
|
3466
|
+
} break;
|
|
3467
|
+
default:
|
|
3468
|
+
{
|
|
3469
|
+
GGML_ABORT("fatal error");
|
|
3470
|
+
}
|
|
3471
|
+
}
|
|
3472
|
+
}
|
|
3473
|
+
|
|
3474
|
+
// ggml_compute_forward_swiglu
|
|
3475
|
+
|
|
3476
|
+
static void ggml_compute_forward_swiglu_f32(
|
|
3477
|
+
const ggml_compute_params * params,
|
|
3478
|
+
ggml_tensor * dst) {
|
|
3479
|
+
|
|
3480
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3481
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
3482
|
+
char * src0_d = (char *) src0->data;
|
|
3483
|
+
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
|
3484
|
+
const size_t src0_o = src0->nb[1];
|
|
3485
|
+
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
|
3486
|
+
|
|
3487
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
3488
|
+
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
|
3489
|
+
|
|
3490
|
+
if (src1) {
|
|
3491
|
+
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
|
3492
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
3493
|
+
}
|
|
3494
|
+
|
|
3495
|
+
const int ith = params->ith;
|
|
3496
|
+
const int nth = params->nth;
|
|
3497
|
+
|
|
3498
|
+
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
|
3499
|
+
const int nr = ggml_nrows(src0);
|
|
3500
|
+
|
|
3501
|
+
GGML_ASSERT(dst->ne[0] == nc);
|
|
3502
|
+
GGML_ASSERT(ggml_nrows(dst) == nr);
|
|
3503
|
+
|
|
3504
|
+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
|
3505
|
+
|
|
3506
|
+
// rows per thread
|
|
3507
|
+
const int dr = (nr + nth - 1)/nth;
|
|
3508
|
+
|
|
3509
|
+
// row range for this thread
|
|
3510
|
+
const int ir0 = dr*ith;
|
|
3511
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
3512
|
+
|
|
3513
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
3514
|
+
float * src0_p = (float *) (src0_d + i1*src0_o);
|
|
3515
|
+
float * src1_p = (float *) (src1_d + i1*src1_o);
|
|
3516
|
+
|
|
3517
|
+
if (!src1) {
|
|
3518
|
+
src0_p += swapped ? nc : 0;
|
|
3519
|
+
src1_p += swapped ? 0 : nc;
|
|
3520
|
+
}
|
|
3521
|
+
|
|
3522
|
+
ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3523
|
+
|
|
3524
|
+
#ifndef NDEBUG
|
|
3525
|
+
for (int k = 0; k < nc; k++) {
|
|
3526
|
+
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
3527
|
+
GGML_UNUSED(x);
|
|
3528
|
+
assert(!isnan(x));
|
|
3529
|
+
assert(!isinf(x));
|
|
3530
|
+
}
|
|
3531
|
+
#endif
|
|
3532
|
+
}
|
|
3533
|
+
}
|
|
3534
|
+
|
|
3535
|
+
static void ggml_compute_forward_swiglu_f16(
|
|
3536
|
+
const ggml_compute_params * params,
|
|
3537
|
+
ggml_tensor * dst) {
|
|
3538
|
+
|
|
3539
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3540
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
3541
|
+
char * src0_d = (char *) src0->data;
|
|
3542
|
+
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
|
3543
|
+
const size_t src0_o = src0->nb[1];
|
|
3544
|
+
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
|
3545
|
+
|
|
3546
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
3547
|
+
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
|
3548
|
+
|
|
3549
|
+
if (src1) {
|
|
3550
|
+
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
|
3551
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
3552
|
+
}
|
|
3553
|
+
|
|
3554
|
+
const int ith = params->ith;
|
|
3555
|
+
const int nth = params->nth;
|
|
3556
|
+
|
|
3557
|
+
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
|
3558
|
+
const int nr = ggml_nrows(src0);
|
|
3559
|
+
|
|
3560
|
+
GGML_ASSERT(dst->ne[0] == nc);
|
|
3561
|
+
GGML_ASSERT(ggml_nrows(dst) == nr);
|
|
3562
|
+
|
|
3563
|
+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
|
3564
|
+
|
|
3565
|
+
// rows per thread
|
|
3566
|
+
const int dr = (nr + nth - 1)/nth;
|
|
3567
|
+
|
|
3568
|
+
// row range for this thread
|
|
3569
|
+
const int ir0 = dr*ith;
|
|
3570
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
3571
|
+
|
|
3572
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
3573
|
+
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
|
|
3574
|
+
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
|
|
3575
|
+
|
|
3576
|
+
if (!src1) {
|
|
3577
|
+
src0_p += swapped ? nc : 0;
|
|
3578
|
+
src1_p += swapped ? 0 : nc;
|
|
3579
|
+
}
|
|
3580
|
+
|
|
3581
|
+
ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3582
|
+
|
|
3583
|
+
#ifndef NDEBUG
|
|
3584
|
+
for (int k = 0; k < nc; k++) {
|
|
3585
|
+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
3586
|
+
const float v = GGML_FP16_TO_FP32(x);
|
|
3587
|
+
GGML_UNUSED(v);
|
|
3588
|
+
assert(!isnan(v));
|
|
3589
|
+
assert(!isinf(v));
|
|
3590
|
+
}
|
|
3591
|
+
#endif
|
|
3592
|
+
}
|
|
3593
|
+
}
|
|
3594
|
+
|
|
3595
|
+
static void ggml_compute_forward_swiglu(
|
|
3596
|
+
const ggml_compute_params * params,
|
|
3597
|
+
ggml_tensor * dst) {
|
|
3598
|
+
|
|
3599
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3600
|
+
|
|
3601
|
+
switch (src0->type) {
|
|
3602
|
+
case GGML_TYPE_F32:
|
|
3603
|
+
{
|
|
3604
|
+
ggml_compute_forward_swiglu_f32(params, dst);
|
|
3605
|
+
} break;
|
|
3606
|
+
case GGML_TYPE_F16:
|
|
3607
|
+
{
|
|
3608
|
+
ggml_compute_forward_swiglu_f16(params, dst);
|
|
3609
|
+
} break;
|
|
3610
|
+
default:
|
|
3611
|
+
{
|
|
3612
|
+
GGML_ABORT("fatal error");
|
|
3613
|
+
}
|
|
3614
|
+
}
|
|
3615
|
+
}
|
|
3616
|
+
|
|
3617
|
+
// ggml_compute_forward_geglu_erf
|
|
3618
|
+
|
|
3619
|
+
static void ggml_compute_forward_geglu_erf_f32(
|
|
3620
|
+
const ggml_compute_params * params,
|
|
3621
|
+
ggml_tensor * dst) {
|
|
3622
|
+
|
|
3623
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3624
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
3625
|
+
char * src0_d = (char *) src0->data;
|
|
3626
|
+
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
|
3627
|
+
const size_t src0_o = src0->nb[1];
|
|
3628
|
+
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
|
3629
|
+
|
|
3630
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
3631
|
+
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
|
3632
|
+
|
|
3633
|
+
if (src1) {
|
|
3634
|
+
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
|
3635
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
3636
|
+
}
|
|
3637
|
+
|
|
3638
|
+
const int ith = params->ith;
|
|
3639
|
+
const int nth = params->nth;
|
|
3640
|
+
|
|
3641
|
+
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
|
3642
|
+
const int nr = ggml_nrows(src0);
|
|
3643
|
+
|
|
3644
|
+
GGML_ASSERT(dst->ne[0] == nc);
|
|
3645
|
+
GGML_ASSERT(ggml_nrows(dst) == nr);
|
|
3646
|
+
|
|
3647
|
+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
|
3648
|
+
|
|
3649
|
+
// rows per thread
|
|
3650
|
+
const int dr = (nr + nth - 1)/nth;
|
|
3651
|
+
|
|
3652
|
+
// row range for this thread
|
|
3653
|
+
const int ir0 = dr*ith;
|
|
3654
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
3655
|
+
|
|
3656
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
3657
|
+
float * src0_p = (float *) (src0_d + i1*src0_o);
|
|
3658
|
+
float * src1_p = (float *) (src1_d + i1*src1_o);
|
|
3659
|
+
|
|
3660
|
+
if (!src1) {
|
|
3661
|
+
src0_p += swapped ? nc : 0;
|
|
3662
|
+
src1_p += swapped ? 0 : nc;
|
|
3663
|
+
}
|
|
3664
|
+
|
|
3665
|
+
ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3666
|
+
|
|
3667
|
+
#ifndef NDEBUG
|
|
3668
|
+
for (int k = 0; k < nc; k++) {
|
|
3669
|
+
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
3670
|
+
GGML_UNUSED(x);
|
|
3671
|
+
assert(!isnan(x));
|
|
3672
|
+
assert(!isinf(x));
|
|
3673
|
+
}
|
|
3674
|
+
#endif
|
|
3675
|
+
}
|
|
3676
|
+
}
|
|
3677
|
+
|
|
3678
|
+
static void ggml_compute_forward_geglu_erf_f16(
|
|
3679
|
+
const ggml_compute_params * params,
|
|
3680
|
+
ggml_tensor * dst) {
|
|
3681
|
+
|
|
3682
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3683
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
3684
|
+
char * src0_d = (char *) src0->data;
|
|
3685
|
+
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
|
3686
|
+
const size_t src0_o = src0->nb[1];
|
|
3687
|
+
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
|
3688
|
+
|
|
3689
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
3690
|
+
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
|
3691
|
+
|
|
3692
|
+
if (src1) {
|
|
3693
|
+
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
|
3694
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
3695
|
+
}
|
|
3696
|
+
|
|
3697
|
+
const int ith = params->ith;
|
|
3698
|
+
const int nth = params->nth;
|
|
3699
|
+
|
|
3700
|
+
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
|
3701
|
+
const int nr = ggml_nrows(src0);
|
|
3702
|
+
|
|
3703
|
+
GGML_ASSERT(dst->ne[0] == nc);
|
|
3704
|
+
GGML_ASSERT(ggml_nrows(dst) == nr);
|
|
3705
|
+
|
|
3706
|
+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
|
3707
|
+
|
|
3708
|
+
// rows per thread
|
|
3709
|
+
const int dr = (nr + nth - 1)/nth;
|
|
3710
|
+
|
|
3711
|
+
// row range for this thread
|
|
3712
|
+
const int ir0 = dr*ith;
|
|
3713
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
3714
|
+
|
|
3715
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
3716
|
+
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
|
|
3717
|
+
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
|
|
3718
|
+
|
|
3719
|
+
if (!src1) {
|
|
3720
|
+
src0_p += swapped ? nc : 0;
|
|
3721
|
+
src1_p += swapped ? 0 : nc;
|
|
3722
|
+
}
|
|
3723
|
+
|
|
3724
|
+
ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3725
|
+
|
|
3726
|
+
#ifndef NDEBUG
|
|
3727
|
+
for (int k = 0; k < nc; k++) {
|
|
3728
|
+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
3729
|
+
const float v = GGML_FP16_TO_FP32(x);
|
|
3730
|
+
GGML_UNUSED(v);
|
|
3731
|
+
assert(!isnan(v));
|
|
3732
|
+
assert(!isinf(v));
|
|
3733
|
+
}
|
|
3734
|
+
#endif
|
|
3735
|
+
}
|
|
3736
|
+
}
|
|
3737
|
+
|
|
3738
|
+
static void ggml_compute_forward_geglu_erf(
|
|
3739
|
+
const ggml_compute_params * params,
|
|
3740
|
+
ggml_tensor * dst) {
|
|
3741
|
+
|
|
3742
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3743
|
+
|
|
3744
|
+
switch (src0->type) {
|
|
3745
|
+
case GGML_TYPE_F32:
|
|
3746
|
+
{
|
|
3747
|
+
ggml_compute_forward_geglu_erf_f32(params, dst);
|
|
3748
|
+
} break;
|
|
3749
|
+
case GGML_TYPE_F16:
|
|
3750
|
+
{
|
|
3751
|
+
ggml_compute_forward_geglu_erf_f16(params, dst);
|
|
3752
|
+
} break;
|
|
3753
|
+
default:
|
|
3754
|
+
{
|
|
3755
|
+
GGML_ABORT("fatal error");
|
|
3756
|
+
}
|
|
3757
|
+
}
|
|
3758
|
+
}
|
|
3759
|
+
|
|
3760
|
+
// ggml_compute_forward_geglu_quick
|
|
3761
|
+
|
|
3762
|
+
static void ggml_compute_forward_geglu_quick_f32(
|
|
3763
|
+
const ggml_compute_params * params,
|
|
3764
|
+
ggml_tensor * dst) {
|
|
3765
|
+
|
|
3766
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3767
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
3768
|
+
char * src0_d = (char *) src0->data;
|
|
3769
|
+
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
|
3770
|
+
const size_t src0_o = src0->nb[1];
|
|
3771
|
+
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
|
3772
|
+
|
|
3773
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
3774
|
+
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
|
3775
|
+
|
|
3776
|
+
if (src1) {
|
|
3777
|
+
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
|
3778
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
3779
|
+
}
|
|
3780
|
+
|
|
3781
|
+
const int ith = params->ith;
|
|
3782
|
+
const int nth = params->nth;
|
|
3783
|
+
|
|
3784
|
+
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
|
3785
|
+
const int nr = ggml_nrows(src0);
|
|
3786
|
+
|
|
3787
|
+
GGML_ASSERT(dst->ne[0] == nc);
|
|
3788
|
+
GGML_ASSERT(ggml_nrows(dst) == nr);
|
|
3789
|
+
|
|
3790
|
+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
|
3791
|
+
|
|
3792
|
+
// rows per thread
|
|
3793
|
+
const int dr = (nr + nth - 1)/nth;
|
|
3794
|
+
|
|
3795
|
+
// row range for this thread
|
|
3796
|
+
const int ir0 = dr*ith;
|
|
3797
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
3798
|
+
|
|
3799
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
3800
|
+
float * src0_p = (float *) (src0_d + i1*src0_o);
|
|
3801
|
+
float * src1_p = (float *) (src1_d + i1*src1_o);
|
|
3802
|
+
|
|
3803
|
+
if (!src1) {
|
|
3804
|
+
src0_p += swapped ? nc : 0;
|
|
3805
|
+
src1_p += swapped ? 0 : nc;
|
|
3806
|
+
}
|
|
3807
|
+
|
|
3808
|
+
ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3809
|
+
|
|
3810
|
+
#ifndef NDEBUG
|
|
3811
|
+
for (int k = 0; k < nc; k++) {
|
|
3812
|
+
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
3813
|
+
GGML_UNUSED(x);
|
|
3814
|
+
assert(!isnan(x));
|
|
3815
|
+
assert(!isinf(x));
|
|
3816
|
+
}
|
|
3817
|
+
#endif
|
|
3818
|
+
}
|
|
3819
|
+
}
|
|
3820
|
+
|
|
3821
|
+
static void ggml_compute_forward_geglu_quick_f16(
|
|
3822
|
+
const ggml_compute_params * params,
|
|
3823
|
+
ggml_tensor * dst) {
|
|
3824
|
+
|
|
3825
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3826
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
3827
|
+
char * src0_d = (char *) src0->data;
|
|
3828
|
+
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
|
3829
|
+
const size_t src0_o = src0->nb[1];
|
|
3830
|
+
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
|
3831
|
+
|
|
3832
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
3833
|
+
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
|
3834
|
+
|
|
3835
|
+
if (src1) {
|
|
3836
|
+
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
|
3837
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
3838
|
+
}
|
|
3839
|
+
|
|
3840
|
+
const int ith = params->ith;
|
|
3841
|
+
const int nth = params->nth;
|
|
3842
|
+
|
|
3843
|
+
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
|
3844
|
+
const int nr = ggml_nrows(src0);
|
|
3845
|
+
|
|
3846
|
+
GGML_ASSERT(dst->ne[0] == nc);
|
|
3847
|
+
GGML_ASSERT(ggml_nrows(dst) == nr);
|
|
3848
|
+
|
|
3849
|
+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
|
3149
3850
|
|
|
3150
3851
|
// rows per thread
|
|
3151
3852
|
const int dr = (nr + nth - 1)/nth;
|
|
@@ -3155,24 +3856,29 @@ static void ggml_compute_forward_silu_back_f16(
|
|
|
3155
3856
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
3156
3857
|
|
|
3157
3858
|
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
3158
|
-
|
|
3159
|
-
|
|
3160
|
-
(ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
|
|
3161
|
-
(ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
|
|
3859
|
+
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
|
|
3860
|
+
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
|
|
3162
3861
|
|
|
3163
|
-
|
|
3862
|
+
if (!src1) {
|
|
3863
|
+
src0_p += swapped ? nc : 0;
|
|
3864
|
+
src1_p += swapped ? 0 : nc;
|
|
3865
|
+
}
|
|
3866
|
+
|
|
3867
|
+
ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3868
|
+
|
|
3869
|
+
#ifndef NDEBUG
|
|
3164
3870
|
for (int k = 0; k < nc; k++) {
|
|
3165
|
-
const
|
|
3871
|
+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
3166
3872
|
const float v = GGML_FP16_TO_FP32(x);
|
|
3167
3873
|
GGML_UNUSED(v);
|
|
3168
3874
|
assert(!isnan(v));
|
|
3169
3875
|
assert(!isinf(v));
|
|
3170
3876
|
}
|
|
3171
|
-
|
|
3877
|
+
#endif
|
|
3172
3878
|
}
|
|
3173
3879
|
}
|
|
3174
3880
|
|
|
3175
|
-
void
|
|
3881
|
+
static void ggml_compute_forward_geglu_quick(
|
|
3176
3882
|
const ggml_compute_params * params,
|
|
3177
3883
|
ggml_tensor * dst) {
|
|
3178
3884
|
|
|
@@ -3181,11 +3887,11 @@ void ggml_compute_forward_silu_back(
|
|
|
3181
3887
|
switch (src0->type) {
|
|
3182
3888
|
case GGML_TYPE_F32:
|
|
3183
3889
|
{
|
|
3184
|
-
|
|
3890
|
+
ggml_compute_forward_geglu_quick_f32(params, dst);
|
|
3185
3891
|
} break;
|
|
3186
3892
|
case GGML_TYPE_F16:
|
|
3187
3893
|
{
|
|
3188
|
-
|
|
3894
|
+
ggml_compute_forward_geglu_quick_f16(params, dst);
|
|
3189
3895
|
} break;
|
|
3190
3896
|
default:
|
|
3191
3897
|
{
|
|
@@ -3937,9 +4643,11 @@ static void ggml_compute_forward_scale_f32(
|
|
|
3937
4643
|
GGML_ASSERT(ggml_is_contiguous(dst));
|
|
3938
4644
|
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
|
3939
4645
|
|
|
3940
|
-
// scale factor
|
|
3941
|
-
float
|
|
3942
|
-
|
|
4646
|
+
float s; // scale factor
|
|
4647
|
+
float b; // bias
|
|
4648
|
+
|
|
4649
|
+
memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
|
|
4650
|
+
memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
|
|
3943
4651
|
|
|
3944
4652
|
const int ith = params->ith;
|
|
3945
4653
|
const int nth = params->nth;
|
|
@@ -3958,12 +4666,22 @@ static void ggml_compute_forward_scale_f32(
|
|
|
3958
4666
|
|
|
3959
4667
|
const size_t nb1 = dst->nb[1];
|
|
3960
4668
|
|
|
3961
|
-
|
|
3962
|
-
|
|
3963
|
-
|
|
3964
|
-
|
|
4669
|
+
if (b == 0.0f) {
|
|
4670
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
4671
|
+
if (dst->data != src0->data) {
|
|
4672
|
+
// src0 is same shape as dst => same indices
|
|
4673
|
+
// TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
|
|
4674
|
+
memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
|
|
4675
|
+
}
|
|
4676
|
+
ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
|
|
4677
|
+
}
|
|
4678
|
+
} else {
|
|
4679
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
4680
|
+
ggml_vec_mad1_f32(nc,
|
|
4681
|
+
(float *) ((char *) dst->data + i1*nb1),
|
|
4682
|
+
(float *) ((char *) src0->data + i1*nb1),
|
|
4683
|
+
s, b);
|
|
3965
4684
|
}
|
|
3966
|
-
ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
|
|
3967
4685
|
}
|
|
3968
4686
|
}
|
|
3969
4687
|
|
|
@@ -4470,6 +5188,74 @@ void ggml_compute_forward_get_rows(
|
|
|
4470
5188
|
//}
|
|
4471
5189
|
}
|
|
4472
5190
|
|
|
5191
|
+
static void ggml_compute_forward_set_rows_f32(
|
|
5192
|
+
const ggml_compute_params * params,
|
|
5193
|
+
ggml_tensor * dst) {
|
|
5194
|
+
|
|
5195
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
5196
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
5197
|
+
|
|
5198
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
|
5199
|
+
|
|
5200
|
+
const int64_t nc = ne00;
|
|
5201
|
+
const int64_t nr = ne01;
|
|
5202
|
+
|
|
5203
|
+
assert(ne0 == nc);
|
|
5204
|
+
assert(ne2 == ne02);
|
|
5205
|
+
assert(ne3 == ne03);
|
|
5206
|
+
assert(src0->type == GGML_TYPE_F32);
|
|
5207
|
+
assert(ne02 % ne11 == 0);
|
|
5208
|
+
assert(ne03 % ne12 == 0);
|
|
5209
|
+
|
|
5210
|
+
const int ith = params->ith;
|
|
5211
|
+
const int nth = params->nth;
|
|
5212
|
+
|
|
5213
|
+
// rows per thread
|
|
5214
|
+
const int64_t dr = (nr + nth - 1)/nth;
|
|
5215
|
+
|
|
5216
|
+
// row range for this thread
|
|
5217
|
+
const int64_t ir0 = dr*ith;
|
|
5218
|
+
const int64_t ir1 = std::min(ir0 + dr, nr);
|
|
5219
|
+
|
|
5220
|
+
ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
|
|
5221
|
+
|
|
5222
|
+
for (int64_t i03 = 0; i03 < ne03; ++i03) {
|
|
5223
|
+
for (int64_t i02 = 0; i02 < ne02; ++i02) {
|
|
5224
|
+
for (int64_t i = ir0; i < ir1; ++i) {
|
|
5225
|
+
const int64_t i12 = i03%ne12;
|
|
5226
|
+
const int64_t i11 = i02%ne11;
|
|
5227
|
+
const int64_t i10 = i;
|
|
5228
|
+
|
|
5229
|
+
const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
|
|
5230
|
+
|
|
5231
|
+
GGML_ASSERT(i1 >= 0 && i1 < ne1);
|
|
5232
|
+
|
|
5233
|
+
from_float(
|
|
5234
|
+
(const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
|
|
5235
|
+
((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc);
|
|
5236
|
+
}
|
|
5237
|
+
}
|
|
5238
|
+
}
|
|
5239
|
+
}
|
|
5240
|
+
|
|
5241
|
+
void ggml_compute_forward_set_rows(
|
|
5242
|
+
const ggml_compute_params * params,
|
|
5243
|
+
ggml_tensor * dst) {
|
|
5244
|
+
|
|
5245
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
5246
|
+
|
|
5247
|
+
switch (src0->type) {
|
|
5248
|
+
case GGML_TYPE_F32:
|
|
5249
|
+
{
|
|
5250
|
+
ggml_compute_forward_set_rows_f32(params, dst);
|
|
5251
|
+
} break;
|
|
5252
|
+
default:
|
|
5253
|
+
{
|
|
5254
|
+
GGML_ABORT("src0->type = %d (%s) not supported", src0->type, ggml_type_name(src0->type));
|
|
5255
|
+
}
|
|
5256
|
+
}
|
|
5257
|
+
}
|
|
5258
|
+
|
|
4473
5259
|
// ggml_compute_forward_get_rows_back
|
|
4474
5260
|
|
|
4475
5261
|
static void ggml_compute_forward_get_rows_back_f32_f16(
|
|
@@ -4500,7 +5286,7 @@ static void ggml_compute_forward_get_rows_back_f32_f16(
|
|
|
4500
5286
|
|
|
4501
5287
|
for (int j = 0; j < nc; ++j) {
|
|
4502
5288
|
ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
|
|
4503
|
-
((float *) ((char *) dst->data + r*dst->nb[1]))[j] +=
|
|
5289
|
+
((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_CPU_FP16_TO_FP32(v);
|
|
4504
5290
|
}
|
|
4505
5291
|
}
|
|
4506
5292
|
}
|
|
@@ -4744,14 +5530,17 @@ static void ggml_compute_forward_soft_max_f32(
|
|
|
4744
5530
|
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
|
4745
5531
|
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
|
4746
5532
|
|
|
4747
|
-
// TODO: handle transposed/permuted matrices
|
|
4748
|
-
|
|
4749
5533
|
const int ith = params->ith;
|
|
4750
5534
|
const int nth = params->nth;
|
|
4751
5535
|
|
|
4752
5536
|
GGML_TENSOR_UNARY_OP_LOCALS
|
|
4753
5537
|
|
|
4754
|
-
|
|
5538
|
+
const int64_t nb11 = src1 ? src1->nb[1] : 1;
|
|
5539
|
+
const int64_t nb12 = src1 ? src1->nb[2] : 1;
|
|
5540
|
+
const int64_t nb13 = src1 ? src1->nb[3] : 1;
|
|
5541
|
+
|
|
5542
|
+
const int64_t ne12 = src1 ? src1->ne[2] : 1;
|
|
5543
|
+
const int64_t ne13 = src1 ? src1->ne[3] : 1;
|
|
4755
5544
|
|
|
4756
5545
|
// TODO: is this supposed to be ceil instead of floor?
|
|
4757
5546
|
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
|
|
@@ -4761,68 +5550,66 @@ static void ggml_compute_forward_soft_max_f32(
|
|
|
4761
5550
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
4762
5551
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
4763
5552
|
|
|
4764
|
-
|
|
4765
|
-
const int nr = ggml_nrows(src0);
|
|
4766
|
-
|
|
4767
|
-
// rows per thread
|
|
4768
|
-
const int dr = (nr + nth - 1)/nth;
|
|
4769
|
-
|
|
4770
|
-
// row range for this thread
|
|
4771
|
-
const int ir0 = dr*ith;
|
|
4772
|
-
const int ir1 = MIN(ir0 + dr, nr);
|
|
4773
|
-
|
|
4774
|
-
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
|
|
5553
|
+
float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
|
4775
5554
|
|
|
4776
5555
|
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
|
4777
5556
|
|
|
4778
|
-
for (
|
|
4779
|
-
|
|
4780
|
-
|
|
4781
|
-
|
|
4782
|
-
|
|
4783
|
-
|
|
4784
|
-
|
|
4785
|
-
|
|
4786
|
-
|
|
4787
|
-
|
|
4788
|
-
|
|
4789
|
-
|
|
4790
|
-
|
|
4791
|
-
|
|
4792
|
-
|
|
4793
|
-
|
|
4794
|
-
|
|
4795
|
-
|
|
4796
|
-
|
|
4797
|
-
|
|
4798
|
-
|
|
4799
|
-
|
|
5557
|
+
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
5558
|
+
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
5559
|
+
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
|
5560
|
+
const int64_t i11 = i01;
|
|
5561
|
+
const int64_t i12 = i02%ne12;
|
|
5562
|
+
const int64_t i13 = i03%ne13;
|
|
5563
|
+
|
|
5564
|
+
// ALiBi
|
|
5565
|
+
const uint32_t h = i02; // head
|
|
5566
|
+
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
|
5567
|
+
|
|
5568
|
+
float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
5569
|
+
float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
|
5570
|
+
|
|
5571
|
+
// broadcast the mask across rows
|
|
5572
|
+
ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
|
|
5573
|
+
float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
|
|
5574
|
+
|
|
5575
|
+
ggml_vec_cpy_f32 (ne00, wp, sp);
|
|
5576
|
+
ggml_vec_scale_f32(ne00, wp, scale);
|
|
5577
|
+
if (mp_f32) {
|
|
5578
|
+
if (use_f16) {
|
|
5579
|
+
for (int i = 0; i < ne00; ++i) {
|
|
5580
|
+
wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
|
|
5581
|
+
}
|
|
5582
|
+
} else {
|
|
5583
|
+
for (int i = 0; i < ne00; ++i) {
|
|
5584
|
+
wp[i] += slope*mp_f32[i];
|
|
5585
|
+
}
|
|
5586
|
+
}
|
|
4800
5587
|
}
|
|
4801
|
-
}
|
|
4802
|
-
}
|
|
4803
5588
|
|
|
4804
5589
|
#ifndef NDEBUG
|
|
4805
|
-
|
|
4806
|
-
|
|
4807
|
-
|
|
4808
|
-
|
|
5590
|
+
for (int i = 0; i < ne00; ++i) {
|
|
5591
|
+
//printf("p[%d] = %f\n", i, p[i]);
|
|
5592
|
+
assert(!isnan(wp[i]));
|
|
5593
|
+
}
|
|
4809
5594
|
#endif
|
|
4810
5595
|
|
|
4811
|
-
|
|
4812
|
-
|
|
5596
|
+
float max = -INFINITY;
|
|
5597
|
+
ggml_vec_max_f32(ne00, &max, wp);
|
|
4813
5598
|
|
|
4814
|
-
|
|
4815
|
-
|
|
5599
|
+
ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
|
|
5600
|
+
assert(sum > 0.0);
|
|
4816
5601
|
|
|
4817
|
-
|
|
4818
|
-
|
|
5602
|
+
sum = 1.0/sum;
|
|
5603
|
+
ggml_vec_scale_f32(ne00, dp, sum);
|
|
4819
5604
|
|
|
4820
5605
|
#ifndef NDEBUG
|
|
4821
|
-
|
|
4822
|
-
|
|
4823
|
-
|
|
4824
|
-
|
|
5606
|
+
for (int i = 0; i < ne00; ++i) {
|
|
5607
|
+
assert(!isnan(dp[i]));
|
|
5608
|
+
assert(!isinf(dp[i]));
|
|
5609
|
+
}
|
|
4825
5610
|
#endif
|
|
5611
|
+
}
|
|
5612
|
+
}
|
|
4826
5613
|
}
|
|
4827
5614
|
}
|
|
4828
5615
|
|
|
@@ -5018,8 +5805,8 @@ static void ggml_compute_forward_clamp_f16(
|
|
|
5018
5805
|
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
|
|
5019
5806
|
|
|
5020
5807
|
for (int i = 0; i < nc; i++) {
|
|
5021
|
-
float v =
|
|
5022
|
-
dst_ptr[i] =
|
|
5808
|
+
float v = GGML_CPU_FP16_TO_FP32(src0_ptr[i]);
|
|
5809
|
+
dst_ptr[i] = GGML_CPU_FP32_TO_FP16(MAX(MIN(v, max), min));
|
|
5023
5810
|
}
|
|
5024
5811
|
}
|
|
5025
5812
|
}
|
|
@@ -5476,11 +6263,11 @@ static void ggml_compute_forward_rope_f16(
|
|
|
5476
6263
|
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
5477
6264
|
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
5478
6265
|
|
|
5479
|
-
const float x0 =
|
|
5480
|
-
const float x1 =
|
|
6266
|
+
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
|
6267
|
+
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
|
|
5481
6268
|
|
|
5482
|
-
dst_data[0] =
|
|
5483
|
-
dst_data[n_dims] =
|
|
6269
|
+
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
6270
|
+
dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
5484
6271
|
}
|
|
5485
6272
|
} else {
|
|
5486
6273
|
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
@@ -5492,11 +6279,11 @@ static void ggml_compute_forward_rope_f16(
|
|
|
5492
6279
|
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
5493
6280
|
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
5494
6281
|
|
|
5495
|
-
const float x0 =
|
|
5496
|
-
const float x1 =
|
|
6282
|
+
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
|
6283
|
+
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
|
|
5497
6284
|
|
|
5498
|
-
dst_data[0] =
|
|
5499
|
-
dst_data[n_dims/2] =
|
|
6285
|
+
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
6286
|
+
dst_data[n_dims/2] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
5500
6287
|
}
|
|
5501
6288
|
}
|
|
5502
6289
|
} else {
|
|
@@ -5507,11 +6294,11 @@ static void ggml_compute_forward_rope_f16(
|
|
|
5507
6294
|
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
5508
6295
|
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
5509
6296
|
|
|
5510
|
-
const float x0 =
|
|
5511
|
-
const float x1 =
|
|
6297
|
+
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
|
6298
|
+
const float x1 = GGML_CPU_FP16_TO_FP32(src[1]);
|
|
5512
6299
|
|
|
5513
|
-
dst_data[0] =
|
|
5514
|
-
dst_data[1] =
|
|
6300
|
+
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
6301
|
+
dst_data[1] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
5515
6302
|
}
|
|
5516
6303
|
}
|
|
5517
6304
|
|
|
@@ -5525,11 +6312,11 @@ static void ggml_compute_forward_rope_f16(
|
|
|
5525
6312
|
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
5526
6313
|
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
5527
6314
|
|
|
5528
|
-
const float x0 =
|
|
5529
|
-
const float x1 =
|
|
6315
|
+
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
|
6316
|
+
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
|
|
5530
6317
|
|
|
5531
|
-
dst_data[0] =
|
|
5532
|
-
dst_data[n_dims] =
|
|
6318
|
+
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
6319
|
+
dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
5533
6320
|
}
|
|
5534
6321
|
} else {
|
|
5535
6322
|
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
|
@@ -5640,7 +6427,7 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
|
|
|
5640
6427
|
for (int64_t i11 = 0; i11 < ne11; i11++) {
|
|
5641
6428
|
const float * const src = (float *)((char *) src1->data + i11*nb11);
|
|
5642
6429
|
for (int64_t i10 = 0; i10 < ne10; i10++) {
|
|
5643
|
-
dst_data[i10*ne11 + i11] =
|
|
6430
|
+
dst_data[i10*ne11 + i11] = GGML_CPU_FP32_TO_FP16(src[i10]);
|
|
5644
6431
|
}
|
|
5645
6432
|
}
|
|
5646
6433
|
}
|
|
@@ -5933,7 +6720,7 @@ static void ggml_compute_forward_im2col_f16(
|
|
|
5933
6720
|
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
|
5934
6721
|
dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
|
|
5935
6722
|
} else {
|
|
5936
|
-
dst_data[iic*(KH*KW) + ikh*KW + ikw] =
|
|
6723
|
+
dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]);
|
|
5937
6724
|
}
|
|
5938
6725
|
}
|
|
5939
6726
|
}
|
|
@@ -6058,6 +6845,186 @@ void ggml_compute_forward_im2col_back_f32(
|
|
|
6058
6845
|
}
|
|
6059
6846
|
}
|
|
6060
6847
|
|
|
6848
|
+
static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
|
|
6849
|
+
void * a, void * b, float * c) {
|
|
6850
|
+
const ggml_type_traits * traits = ggml_get_type_traits(type);
|
|
6851
|
+
struct ggml_tensor src1 = {};
|
|
6852
|
+
src1.type = type;
|
|
6853
|
+
src1.ne[0] = k;
|
|
6854
|
+
src1.ne[1] = m;
|
|
6855
|
+
src1.ne[2] = 1;
|
|
6856
|
+
src1.ne[3] = 1;
|
|
6857
|
+
src1.nb[0] = traits->type_size;
|
|
6858
|
+
src1.nb[1] = k * traits->type_size;
|
|
6859
|
+
src1.nb[2] = src1.nb[1];
|
|
6860
|
+
src1.nb[3] = src1.nb[2];
|
|
6861
|
+
src1.data = a;
|
|
6862
|
+
|
|
6863
|
+
struct ggml_tensor src0 = {};
|
|
6864
|
+
src0.type = type;
|
|
6865
|
+
src0.ne[0] = k;
|
|
6866
|
+
src0.ne[1] = n;
|
|
6867
|
+
src0.ne[2] = 1;
|
|
6868
|
+
src0.ne[3] = 1;
|
|
6869
|
+
src0.nb[0] = traits->type_size;
|
|
6870
|
+
src0.nb[1] = k * traits->type_size;
|
|
6871
|
+
src0.nb[2] = src0.nb[1];
|
|
6872
|
+
src0.nb[3] = src0.nb[2];
|
|
6873
|
+
src0.data = b;
|
|
6874
|
+
|
|
6875
|
+
struct ggml_tensor dst = {};
|
|
6876
|
+
dst.ne[0] = n;
|
|
6877
|
+
dst.ne[1] = m;
|
|
6878
|
+
dst.ne[2] = 1;
|
|
6879
|
+
dst.ne[3] = 1;
|
|
6880
|
+
dst.nb[0] = sizeof(float);
|
|
6881
|
+
dst.nb[1] = n * sizeof(float);
|
|
6882
|
+
dst.nb[2] = dst.nb[1];
|
|
6883
|
+
dst.nb[3] = dst.nb[2];
|
|
6884
|
+
dst.data = c;
|
|
6885
|
+
dst.src[0] = &src0;
|
|
6886
|
+
dst.src[1] = &src1;
|
|
6887
|
+
|
|
6888
|
+
ggml_compute_forward_mul_mat(params, &dst);
|
|
6889
|
+
}
|
|
6890
|
+
|
|
6891
|
+
// ggml_compute_forward_conv_2d
|
|
6892
|
+
|
|
6893
|
+
static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
|
|
6894
|
+
const ggml_tensor * kernel, // [KW, KH, IC, OC]
|
|
6895
|
+
const ggml_tensor * src, // [W, H, C, N]
|
|
6896
|
+
ggml_tensor * dst, // [OW, OH, OC, N]
|
|
6897
|
+
ggml_type kernel_type) {
|
|
6898
|
+
|
|
6899
|
+
GGML_ASSERT(ggml_is_contiguous(kernel));
|
|
6900
|
+
GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
|
|
6901
|
+
GGML_ASSERT(kernel->type == kernel_type);
|
|
6902
|
+
|
|
6903
|
+
const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
|
|
6904
|
+
|
|
6905
|
+
const int32_t stride_x = dst->op_params[0];
|
|
6906
|
+
const int32_t stride_y = dst->op_params[1];
|
|
6907
|
+
const int32_t pad_x = dst->op_params[2];
|
|
6908
|
+
const int32_t pad_y = dst->op_params[3];
|
|
6909
|
+
const int32_t dilation_x = dst->op_params[4];
|
|
6910
|
+
const int32_t dilation_y = dst->op_params[5];
|
|
6911
|
+
|
|
6912
|
+
const int64_t c_in = src->ne[2];
|
|
6913
|
+
const int64_t c_out = kernel->ne[3];
|
|
6914
|
+
GGML_ASSERT(c_in == kernel->ne[2]);
|
|
6915
|
+
|
|
6916
|
+
const int64_t src_w = src->ne[0];
|
|
6917
|
+
const int64_t src_h = src->ne[1];
|
|
6918
|
+
const int64_t knl_w = kernel->ne[0];
|
|
6919
|
+
const int64_t knl_h = kernel->ne[1];
|
|
6920
|
+
const int64_t dst_w = dst->ne[0];
|
|
6921
|
+
const int64_t dst_h = dst->ne[1];
|
|
6922
|
+
|
|
6923
|
+
const float * src_data = (float *) src->data;
|
|
6924
|
+
void * knl_data = kernel->data;
|
|
6925
|
+
float * dst_data = (float *) dst->data;
|
|
6926
|
+
|
|
6927
|
+
const int64_t knl_n = knl_w * knl_h * c_in;
|
|
6928
|
+
const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
|
|
6929
|
+
|
|
6930
|
+
const int64_t space_per_patch = knl_n * traits->type_size + c_out * sizeof(float);
|
|
6931
|
+
const int64_t batch_size = params->wsize / space_per_patch;
|
|
6932
|
+
const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
|
|
6933
|
+
const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
|
|
6934
|
+
|
|
6935
|
+
GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
|
|
6936
|
+
|
|
6937
|
+
void * tmp = params->wdata;
|
|
6938
|
+
|
|
6939
|
+
for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
|
|
6940
|
+
|
|
6941
|
+
const int64_t patch_start_batch = batch_i * patches_per_batch;
|
|
6942
|
+
const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch,
|
|
6943
|
+
patch_total);
|
|
6944
|
+
const int64_t patch_n = patch_end_batch - patch_start_batch;
|
|
6945
|
+
|
|
6946
|
+
const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth;
|
|
6947
|
+
const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
|
|
6948
|
+
const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
|
|
6949
|
+
|
|
6950
|
+
//im2col for a patch
|
|
6951
|
+
for (int64_t p = patch_start; p < patch_end; ++p) {
|
|
6952
|
+
const int64_t batch_n = p / (dst_w * dst_h);
|
|
6953
|
+
const int64_t src_x = (p / dst_w) % dst_h;
|
|
6954
|
+
const int64_t src_y = p % dst_w;
|
|
6955
|
+
|
|
6956
|
+
const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]);
|
|
6957
|
+
char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
|
|
6958
|
+
|
|
6959
|
+
for (int64_t ic = 0; ic < c_in; ++ic) {
|
|
6960
|
+
for (int64_t ky = 0; ky < knl_h; ++ky) {
|
|
6961
|
+
for (int64_t kx = 0; kx < knl_w; ++kx) {
|
|
6962
|
+
const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
|
|
6963
|
+
const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;
|
|
6964
|
+
|
|
6965
|
+
int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
|
|
6966
|
+
|
|
6967
|
+
float src_val;
|
|
6968
|
+
if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
|
|
6969
|
+
src_val = 0.0f;
|
|
6970
|
+
} else {
|
|
6971
|
+
const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
|
|
6972
|
+
src_val = *src_ptr;
|
|
6973
|
+
}
|
|
6974
|
+
|
|
6975
|
+
char * element_ptr = dst_row + dst_idx * traits->type_size;
|
|
6976
|
+
if (kernel_type == GGML_TYPE_F32) {
|
|
6977
|
+
*(float *) element_ptr = src_val;
|
|
6978
|
+
} else if (kernel_type == GGML_TYPE_F16) {
|
|
6979
|
+
*(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
|
|
6980
|
+
}
|
|
6981
|
+
}
|
|
6982
|
+
}
|
|
6983
|
+
}
|
|
6984
|
+
} // patches handled by this thread
|
|
6985
|
+
|
|
6986
|
+
ggml_barrier(params->threadpool);
|
|
6987
|
+
|
|
6988
|
+
float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);
|
|
6989
|
+
|
|
6990
|
+
GGML_ASSERT(gemm_output + patch_n * c_out <= (float*)tmp + params->wsize);
|
|
6991
|
+
|
|
6992
|
+
// GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
|
|
6993
|
+
ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
|
|
6994
|
+
|
|
6995
|
+
ggml_barrier(params->threadpool);
|
|
6996
|
+
|
|
6997
|
+
|
|
6998
|
+
//permute back [OC, N, OH, OW] to [N, OC, OH, OW]
|
|
6999
|
+
const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
|
|
7000
|
+
const int64_t permute_start = params->ith * permute_per_thread;
|
|
7001
|
+
const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);
|
|
7002
|
+
|
|
7003
|
+
for (int64_t i = permute_start; i < permute_end; ++i) {
|
|
7004
|
+
const int64_t p = patch_start_batch + i;
|
|
7005
|
+
const int64_t batch_n = p / (dst_w * dst_h);
|
|
7006
|
+
const int64_t dst_y = (p / dst_w) % dst_h;
|
|
7007
|
+
const int64_t dst_x = p % dst_w;
|
|
7008
|
+
|
|
7009
|
+
for (int64_t oc = 0; oc < c_out; ++oc) {
|
|
7010
|
+
const float value = gemm_output[i * c_out + oc];
|
|
7011
|
+
float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
|
|
7012
|
+
*dst_ptr = value;
|
|
7013
|
+
}
|
|
7014
|
+
}
|
|
7015
|
+
}
|
|
7016
|
+
}
|
|
7017
|
+
|
|
7018
|
+
void ggml_compute_forward_conv_2d(
|
|
7019
|
+
const ggml_compute_params * params,
|
|
7020
|
+
ggml_tensor * dst) {
|
|
7021
|
+
|
|
7022
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
7023
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
7024
|
+
|
|
7025
|
+
ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
|
|
7026
|
+
}
|
|
7027
|
+
|
|
6061
7028
|
// ggml_compute_forward_conv_transpose_2d
|
|
6062
7029
|
|
|
6063
7030
|
void ggml_compute_forward_conv_transpose_2d(
|
|
@@ -6109,7 +7076,7 @@ void ggml_compute_forward_conv_transpose_2d(
|
|
|
6109
7076
|
const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
|
|
6110
7077
|
ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
|
|
6111
7078
|
for (int i10 = 0; i10 < ne10; i10++) {
|
|
6112
|
-
dst_data[i10*ne12 + i12] =
|
|
7079
|
+
dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
|
|
6113
7080
|
}
|
|
6114
7081
|
}
|
|
6115
7082
|
}
|
|
@@ -6358,7 +7325,7 @@ static void ggml_compute_forward_pool_1d_sk_p0(
|
|
|
6358
7325
|
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
|
6359
7326
|
}
|
|
6360
7327
|
for (int ki = 0; ki < k; ++ki) {
|
|
6361
|
-
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] :
|
|
7328
|
+
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
|
|
6362
7329
|
switch (op) {
|
|
6363
7330
|
case GGML_OP_POOL_AVG: drow[i] += srow_j; break;
|
|
6364
7331
|
case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
|
|
@@ -6450,7 +7417,7 @@ void ggml_compute_forward_pool_2d(
|
|
|
6450
7417
|
for (int kx = 0; kx < k0; ++kx) {
|
|
6451
7418
|
int j = ix + kx;
|
|
6452
7419
|
if (j < 0 || j >= src->ne[0]) continue;
|
|
6453
|
-
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] :
|
|
7420
|
+
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
|
|
6454
7421
|
switch (op) {
|
|
6455
7422
|
case GGML_OP_POOL_AVG: *out += srow_j; break;
|
|
6456
7423
|
case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
|
|
@@ -6538,7 +7505,7 @@ void ggml_compute_forward_pool_2d_back(
|
|
|
6538
7505
|
}
|
|
6539
7506
|
|
|
6540
7507
|
const float val = dst->type == GGML_TYPE_F32 ?
|
|
6541
|
-
((const float *) drowf)[j] :
|
|
7508
|
+
((const float *) drowf)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);
|
|
6542
7509
|
if (val <= maxval) {
|
|
6543
7510
|
continue;
|
|
6544
7511
|
}
|
|
@@ -6558,7 +7525,7 @@ void ggml_compute_forward_pool_2d_back(
|
|
|
6558
7525
|
if (dst->type == GGML_TYPE_F32) {
|
|
6559
7526
|
((float *) drow)[j] += grad0;
|
|
6560
7527
|
} else {
|
|
6561
|
-
((ggml_fp16_t *) drow)[j] =
|
|
7528
|
+
((ggml_fp16_t *) drow)[j] = GGML_CPU_FP32_TO_FP16(grad0 + GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
|
|
6562
7529
|
}
|
|
6563
7530
|
} else if (op == GGML_OP_POOL_AVG) {
|
|
6564
7531
|
const float grad = grad0 / ka;
|
|
@@ -6577,7 +7544,7 @@ void ggml_compute_forward_pool_2d_back(
|
|
|
6577
7544
|
if (dst->type == GGML_TYPE_F32) {
|
|
6578
7545
|
((float *) drow)[j] += grad;
|
|
6579
7546
|
} else {
|
|
6580
|
-
((ggml_fp16_t *) drow)[j] +=
|
|
7547
|
+
((ggml_fp16_t *) drow)[j] += GGML_CPU_FP32_TO_FP16(grad);
|
|
6581
7548
|
}
|
|
6582
7549
|
}
|
|
6583
7550
|
}
|
|
@@ -6608,12 +7575,13 @@ static void ggml_compute_forward_upscale_f32(
|
|
|
6608
7575
|
|
|
6609
7576
|
GGML_TENSOR_UNARY_OP_LOCALS
|
|
6610
7577
|
|
|
6611
|
-
|
|
6612
|
-
|
|
6613
|
-
|
|
6614
|
-
|
|
7578
|
+
float sf0 = (float)ne0/src0->ne[0];
|
|
7579
|
+
float sf1 = (float)ne1/src0->ne[1];
|
|
7580
|
+
float sf2 = (float)ne2/src0->ne[2];
|
|
7581
|
+
float sf3 = (float)ne3/src0->ne[3];
|
|
6615
7582
|
|
|
6616
|
-
const
|
|
7583
|
+
const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
|
|
7584
|
+
const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
|
|
6617
7585
|
|
|
6618
7586
|
if (mode == GGML_SCALE_MODE_NEAREST) {
|
|
6619
7587
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
@@ -6634,8 +7602,12 @@ static void ggml_compute_forward_upscale_f32(
|
|
|
6634
7602
|
}
|
|
6635
7603
|
}
|
|
6636
7604
|
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
|
|
6637
|
-
|
|
6638
|
-
|
|
7605
|
+
float pixel_offset = 0.5f;
|
|
7606
|
+
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
|
7607
|
+
pixel_offset = 0.0f;
|
|
7608
|
+
sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
|
|
7609
|
+
sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
|
|
7610
|
+
}
|
|
6639
7611
|
|
|
6640
7612
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
6641
7613
|
const int64_t i03 = i3 / sf3;
|
|
@@ -6793,6 +7765,73 @@ void ggml_compute_forward_pad_reflect_1d(
|
|
|
6793
7765
|
}
|
|
6794
7766
|
}
|
|
6795
7767
|
|
|
7768
|
+
// ggml_compute_forward_roll
|
|
7769
|
+
|
|
7770
|
+
static int64_t ggml_wrap_index(int64_t i, int64_t ne) {
|
|
7771
|
+
if (i < 0) {
|
|
7772
|
+
return i + ne;
|
|
7773
|
+
} else if (i >= ne) {
|
|
7774
|
+
return i - ne;
|
|
7775
|
+
}
|
|
7776
|
+
return i;
|
|
7777
|
+
}
|
|
7778
|
+
|
|
7779
|
+
static void ggml_compute_forward_roll_f32(
|
|
7780
|
+
const ggml_compute_params * params,
|
|
7781
|
+
ggml_tensor * dst) {
|
|
7782
|
+
|
|
7783
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
7784
|
+
const float * src_data = (const float *) src0->data;
|
|
7785
|
+
float * dst_data = (float *) dst->data;
|
|
7786
|
+
|
|
7787
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
|
7788
|
+
|
|
7789
|
+
const int s0 = ggml_get_op_params_i32(dst, 0);
|
|
7790
|
+
const int s1 = ggml_get_op_params_i32(dst, 1);
|
|
7791
|
+
const int s2 = ggml_get_op_params_i32(dst, 2);
|
|
7792
|
+
const int s3 = ggml_get_op_params_i32(dst, 3);
|
|
7793
|
+
|
|
7794
|
+
const int64_t total = ne1 * ne2 * ne3;
|
|
7795
|
+
const int64_t per_thread = (total + params->nth) / params->nth;
|
|
7796
|
+
const int64_t start = params->ith * per_thread;
|
|
7797
|
+
const int64_t end = std::min(start + per_thread, total);
|
|
7798
|
+
|
|
7799
|
+
for (int64_t i = start; i < end; ++i) {
|
|
7800
|
+
const int64_t i1 = i % ne1;
|
|
7801
|
+
const int64_t i2 = (i / ne1) % ne2;
|
|
7802
|
+
const int64_t i3 = i / (ne2 * ne1);
|
|
7803
|
+
float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof(float);
|
|
7804
|
+
|
|
7805
|
+
const int64_t i01 = ggml_wrap_index(i1 - s1, ne01);
|
|
7806
|
+
const int64_t i02 = ggml_wrap_index(i2 - s2, ne02);
|
|
7807
|
+
const int64_t i03 = ggml_wrap_index(i3 - s3, ne03);
|
|
7808
|
+
const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof(float);
|
|
7809
|
+
|
|
7810
|
+
const int64_t s = ggml_wrap_index(-s0, ne00);
|
|
7811
|
+
const int64_t n = ne00 - s;
|
|
7812
|
+
ggml_vec_cpy_f32(n, dst_row, src_row + s);
|
|
7813
|
+
ggml_vec_cpy_f32(s, dst_row + n, src_row);
|
|
7814
|
+
}
|
|
7815
|
+
}
|
|
7816
|
+
|
|
7817
|
+
void ggml_compute_forward_roll(
|
|
7818
|
+
const ggml_compute_params * params,
|
|
7819
|
+
ggml_tensor * dst) {
|
|
7820
|
+
|
|
7821
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
7822
|
+
|
|
7823
|
+
switch (src0->type) {
|
|
7824
|
+
case GGML_TYPE_F32:
|
|
7825
|
+
{
|
|
7826
|
+
ggml_compute_forward_roll_f32(params, dst);
|
|
7827
|
+
} break;
|
|
7828
|
+
default:
|
|
7829
|
+
{
|
|
7830
|
+
GGML_ABORT("fatal error");
|
|
7831
|
+
}
|
|
7832
|
+
}
|
|
7833
|
+
}
|
|
7834
|
+
|
|
6796
7835
|
// ggml_compute_forward_arange
|
|
6797
7836
|
|
|
6798
7837
|
static void ggml_compute_forward_arange_f32(
|
|
@@ -7026,7 +8065,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
7026
8065
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
7027
8066
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
7028
8067
|
|
|
7029
|
-
ggml_type
|
|
8068
|
+
ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
|
|
7030
8069
|
ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
|
|
7031
8070
|
ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
|
|
7032
8071
|
ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
|
|
@@ -7058,7 +8097,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
7058
8097
|
memset(VKQ32, 0, DV*sizeof(float));
|
|
7059
8098
|
}
|
|
7060
8099
|
|
|
7061
|
-
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
|
|
8100
|
+
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
|
|
7062
8101
|
|
|
7063
8102
|
// k indices
|
|
7064
8103
|
const int ik3 = iq3 / rk3;
|
|
@@ -7075,7 +8114,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
7075
8114
|
// loop over n_kv and n_head_kv
|
|
7076
8115
|
// ref: https://arxiv.org/pdf/2112.05682.pdf
|
|
7077
8116
|
for (int64_t ic = 0; ic < nek1; ++ic) {
|
|
7078
|
-
const float mv = mp ? slope*
|
|
8117
|
+
const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
|
|
7079
8118
|
if (mv == -INFINITY) {
|
|
7080
8119
|
continue;
|
|
7081
8120
|
}
|
|
@@ -7143,7 +8182,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
7143
8182
|
|
|
7144
8183
|
if (v->type == GGML_TYPE_F16) {
|
|
7145
8184
|
for (int64_t d = 0; d < DV; ++d) {
|
|
7146
|
-
VKQ32[d] =
|
|
8185
|
+
VKQ32[d] = GGML_CPU_FP16_TO_FP32(VKQ16[d]);
|
|
7147
8186
|
}
|
|
7148
8187
|
}
|
|
7149
8188
|
|
|
@@ -7596,120 +8635,210 @@ void ggml_compute_forward_ssm_conv(
|
|
|
7596
8635
|
static void ggml_compute_forward_ssm_scan_f32(
|
|
7597
8636
|
const ggml_compute_params * params,
|
|
7598
8637
|
ggml_tensor * dst) {
|
|
7599
|
-
const ggml_tensor * src0 = dst->src[0]; // s
|
|
7600
|
-
const ggml_tensor * src1 = dst->src[1]; // x
|
|
7601
|
-
const ggml_tensor * src2 = dst->src[2]; // dt
|
|
7602
|
-
const ggml_tensor * src3 = dst->src[3]; // A
|
|
7603
|
-
const ggml_tensor * src4 = dst->src[4]; // B
|
|
7604
|
-
const ggml_tensor * src5 = dst->src[5]; // C
|
|
8638
|
+
const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
|
|
8639
|
+
const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
|
|
8640
|
+
const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
|
|
8641
|
+
const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
|
|
8642
|
+
const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
|
|
8643
|
+
const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
|
|
8644
|
+
const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
|
|
7605
8645
|
|
|
7606
8646
|
const int ith = params->ith;
|
|
7607
8647
|
const int nth = params->nth;
|
|
7608
8648
|
|
|
7609
|
-
const int64_t nc
|
|
7610
|
-
const int64_t nr
|
|
7611
|
-
const int64_t
|
|
7612
|
-
const int64_t
|
|
8649
|
+
const int64_t nc = src0->ne[0]; // d_state
|
|
8650
|
+
const int64_t nr = src0->ne[1]; // dim
|
|
8651
|
+
const int64_t nh = src1->ne[1]; // n_head
|
|
8652
|
+
const int64_t ng = src4->ne[1];
|
|
8653
|
+
const int64_t nt = src1->ne[2]; // number of tokens per sequence
|
|
8654
|
+
const int64_t ns = src1->ne[3]; // number of sequences in the batch
|
|
8655
|
+
|
|
8656
|
+
// can't use ggml_nbytes because src1 is not necessarily contiguous
|
|
8657
|
+
const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
|
|
7613
8658
|
|
|
7614
|
-
GGML_ASSERT(ggml_nelements(src1) +
|
|
8659
|
+
GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
|
|
7615
8660
|
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
|
7616
8661
|
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
|
7617
8662
|
GGML_ASSERT(src2->nb[0] == sizeof(float));
|
|
7618
8663
|
GGML_ASSERT(src3->nb[0] == sizeof(float));
|
|
7619
8664
|
GGML_ASSERT(src4->nb[0] == sizeof(float));
|
|
7620
8665
|
GGML_ASSERT(src5->nb[0] == sizeof(float));
|
|
7621
|
-
|
|
7622
|
-
|
|
7623
|
-
|
|
7624
|
-
GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
|
|
7625
|
-
// required to get correct offset for state destination (i.e. src1->nb[3])
|
|
7626
|
-
GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
|
|
8666
|
+
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
|
|
8667
|
+
// allows optimizing the modulo since n_group should be a power of 2
|
|
8668
|
+
GGML_ASSERT((ng & -ng) == ng);
|
|
7627
8669
|
|
|
7628
|
-
//
|
|
7629
|
-
const int
|
|
8670
|
+
// heads per thread
|
|
8671
|
+
const int dh = (nh + nth - 1)/nth;
|
|
7630
8672
|
|
|
7631
|
-
//
|
|
7632
|
-
const int
|
|
7633
|
-
const int
|
|
7634
|
-
|
|
8673
|
+
// head range for this thread
|
|
8674
|
+
const int ih0 = dh*ith;
|
|
8675
|
+
const int ih1 = MIN(ih0 + dh, nh);
|
|
8676
|
+
|
|
8677
|
+
const int32_t * ids = (const int32_t *) src6->data;
|
|
7635
8678
|
|
|
7636
|
-
|
|
7637
|
-
|
|
7638
|
-
|
|
7639
|
-
|
|
7640
|
-
|
|
7641
|
-
|
|
7642
|
-
|
|
7643
|
-
|
|
7644
|
-
|
|
7645
|
-
|
|
7646
|
-
|
|
7647
|
-
|
|
7648
|
-
|
|
7649
|
-
|
|
7650
|
-
|
|
7651
|
-
//
|
|
7652
|
-
for (int
|
|
7653
|
-
|
|
7654
|
-
float
|
|
7655
|
-
|
|
7656
|
-
|
|
7657
|
-
|
|
7658
|
-
|
|
7659
|
-
|
|
7660
|
-
|
|
7661
|
-
|
|
7662
|
-
|
|
7663
|
-
|
|
7664
|
-
|
|
7665
|
-
|
|
7666
|
-
|
|
7667
|
-
|
|
7668
|
-
|
|
7669
|
-
|
|
7670
|
-
|
|
7671
|
-
|
|
7672
|
-
|
|
8679
|
+
for (int i3 = 0; i3 < ns; ++i3) {
|
|
8680
|
+
const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
|
|
8681
|
+
float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
|
|
8682
|
+
|
|
8683
|
+
for (int i2 = 0; i2 < nt; ++i2) {
|
|
8684
|
+
const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
|
|
8685
|
+
const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
|
|
8686
|
+
const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
|
|
8687
|
+
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
|
|
8688
|
+
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
|
|
8689
|
+
float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
|
|
8690
|
+
|
|
8691
|
+
if (src3->ne[0] == 1) {
|
|
8692
|
+
// Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
|
|
8693
|
+
|
|
8694
|
+
// n_head
|
|
8695
|
+
for (int h = ih0; h < ih1; ++h) {
|
|
8696
|
+
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
|
8697
|
+
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
|
|
8698
|
+
const float dA = expf(dt_soft_plus * A[h]);
|
|
8699
|
+
|
|
8700
|
+
// dim
|
|
8701
|
+
for (int i1 = 0; i1 < nr; ++i1) {
|
|
8702
|
+
const int ii = i1 + h*nr;
|
|
8703
|
+
const float x_dt = x[ii] * dt_soft_plus;
|
|
8704
|
+
float sumf = 0.0f;
|
|
8705
|
+
#if defined(GGML_SIMD)
|
|
8706
|
+
#if defined(__ARM_FEATURE_SVE)
|
|
8707
|
+
const int ggml_f32_epr = svcntw();
|
|
8708
|
+
const int ggml_f32_step = 1 * ggml_f32_epr;
|
|
8709
|
+
|
|
8710
|
+
const int np = (nc & ~(ggml_f32_step - 1));
|
|
8711
|
+
|
|
8712
|
+
GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
|
|
8713
|
+
|
|
8714
|
+
GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
|
|
8715
|
+
GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
|
|
8716
|
+
|
|
8717
|
+
for (int i = 0; i < np; i += ggml_f32_step) {
|
|
8718
|
+
// TODO: maybe unroll more?
|
|
8719
|
+
for (int j = 0; j < 1; j++) {
|
|
8720
|
+
GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
|
|
8721
|
+
GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
|
|
8722
|
+
GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
|
|
8723
|
+
|
|
8724
|
+
t0 = GGML_F32_VEC_MUL(t0, adA);
|
|
8725
|
+
t1 = GGML_F32_VEC_MUL(t1, axdt);
|
|
8726
|
+
|
|
8727
|
+
t0 = GGML_F32_VEC_ADD(t0, t1);
|
|
8728
|
+
|
|
8729
|
+
sum = GGML_F32_VEC_FMA(sum, t0, t2);
|
|
8730
|
+
|
|
8731
|
+
GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
|
|
8732
|
+
}
|
|
8733
|
+
}
|
|
8734
|
+
|
|
8735
|
+
sumf = GGML_F32xt_REDUCE_ONE(sum);
|
|
8736
|
+
#else
|
|
8737
|
+
const int np = (nc & ~(GGML_F32_STEP - 1));
|
|
8738
|
+
|
|
8739
|
+
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
|
|
8740
|
+
|
|
8741
|
+
GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
|
|
8742
|
+
GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
|
|
8743
|
+
|
|
8744
|
+
GGML_F32_VEC ax[GGML_F32_ARR];
|
|
8745
|
+
GGML_F32_VEC ay[GGML_F32_ARR];
|
|
8746
|
+
GGML_F32_VEC az[GGML_F32_ARR];
|
|
8747
|
+
|
|
8748
|
+
for (int i = 0; i < np; i += GGML_F32_STEP) {
|
|
8749
|
+
for (int j = 0; j < GGML_F32_ARR; j++) {
|
|
8750
|
+
ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
|
|
8751
|
+
ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
|
|
8752
|
+
az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
|
|
8753
|
+
|
|
8754
|
+
ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
|
|
8755
|
+
ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
|
|
8756
|
+
|
|
8757
|
+
ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
|
|
8758
|
+
|
|
8759
|
+
sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
|
|
8760
|
+
|
|
8761
|
+
GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
|
|
8762
|
+
}
|
|
8763
|
+
}
|
|
8764
|
+
|
|
8765
|
+
// reduce sum0..sum3 to sum0
|
|
8766
|
+
GGML_F32_VEC_REDUCE(sumf, sum);
|
|
8767
|
+
#endif
|
|
8768
|
+
#else
|
|
8769
|
+
const int np = 0;
|
|
8770
|
+
#endif
|
|
8771
|
+
// d_state
|
|
8772
|
+
for (int i0 = np; i0 < nc; ++i0) {
|
|
8773
|
+
const int i = i0 + ii*nc;
|
|
8774
|
+
const int ig = i0 + (h & (ng - 1))*nc;
|
|
8775
|
+
// state = prev_state * dA + dB * x
|
|
8776
|
+
const float state = (s0[i] * dA) + (B[ig] * x_dt);
|
|
8777
|
+
// y = rowwise_dotprod(state, C)
|
|
8778
|
+
sumf += state * C[ig];
|
|
8779
|
+
s[i] = state;
|
|
8780
|
+
}
|
|
8781
|
+
y[ii] = sumf;
|
|
7673
8782
|
}
|
|
7674
|
-
y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
|
|
7675
8783
|
}
|
|
7676
|
-
}
|
|
7677
|
-
|
|
7678
|
-
|
|
7679
|
-
|
|
7680
|
-
|
|
7681
|
-
|
|
7682
|
-
|
|
7683
|
-
|
|
7684
|
-
|
|
7685
|
-
|
|
7686
|
-
|
|
7687
|
-
|
|
7688
|
-
|
|
7689
|
-
|
|
7690
|
-
|
|
7691
|
-
|
|
7692
|
-
|
|
7693
|
-
|
|
7694
|
-
|
|
7695
|
-
|
|
7696
|
-
|
|
7697
|
-
|
|
7698
|
-
|
|
7699
|
-
|
|
7700
|
-
|
|
7701
|
-
|
|
7702
|
-
|
|
7703
|
-
|
|
7704
|
-
|
|
7705
|
-
|
|
7706
|
-
|
|
8784
|
+
} else {
|
|
8785
|
+
// Mamba-1 has an element-wise decay factor for the states
|
|
8786
|
+
|
|
8787
|
+
// n_head
|
|
8788
|
+
for (int h = ih0; h < ih1; ++h) {
|
|
8789
|
+
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
|
8790
|
+
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
|
|
8791
|
+
|
|
8792
|
+
// dim
|
|
8793
|
+
for (int i1 = 0; i1 < nr; ++i1) {
|
|
8794
|
+
const int ii = i1 + h*nr;
|
|
8795
|
+
const float x_dt = x[ii] * dt_soft_plus;
|
|
8796
|
+
#if defined(__ARM_FEATURE_SVE)
|
|
8797
|
+
svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
|
|
8798
|
+
svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
|
|
8799
|
+
svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
|
|
8800
|
+
|
|
8801
|
+
// d_state
|
|
8802
|
+
// TODO: what happens when (d_state % svcntw()) != 0?
|
|
8803
|
+
for (int64_t k = 0; k < nc; k += svcntw()) {
|
|
8804
|
+
svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
|
|
8805
|
+
svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
|
|
8806
|
+
svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
|
|
8807
|
+
svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
|
|
8808
|
+
|
|
8809
|
+
svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
|
|
8810
|
+
t1 = exp_ps_sve(svptrue_b32(), t1);
|
|
8811
|
+
svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
|
|
8812
|
+
|
|
8813
|
+
vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
|
|
8814
|
+
r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
|
|
8815
|
+
|
|
8816
|
+
GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
|
|
8817
|
+
}
|
|
8818
|
+
y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
|
|
8819
|
+
#else
|
|
8820
|
+
float sumf = 0.0f;
|
|
8821
|
+
// NOTE: can't really use GGML_SIMD here because d_state is usually 16
|
|
8822
|
+
// and also because expf is used within the loop.
|
|
8823
|
+
// d_state
|
|
8824
|
+
for (int i0 = 0; i0 < nc; ++i0) {
|
|
8825
|
+
const int i = i0 + ii*nc;
|
|
8826
|
+
const int ig = i0 + (h & (ng - 1))*nc;
|
|
8827
|
+
// state = prev_state * dA + dB * x
|
|
8828
|
+
const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
|
|
8829
|
+
// y = rowwise_dotprod(state, C)
|
|
8830
|
+
sumf += state * C[ig];
|
|
8831
|
+
s[i] = state;
|
|
8832
|
+
}
|
|
8833
|
+
y[ii] = sumf;
|
|
8834
|
+
#endif
|
|
7707
8835
|
}
|
|
7708
|
-
y[i1] = sumf;
|
|
7709
8836
|
}
|
|
7710
8837
|
}
|
|
8838
|
+
// use the output as the source when it's not the first token-wise iteration
|
|
8839
|
+
s0 = s;
|
|
7711
8840
|
}
|
|
7712
|
-
|
|
8841
|
+
}
|
|
7713
8842
|
}
|
|
7714
8843
|
|
|
7715
8844
|
void ggml_compute_forward_ssm_scan(
|
|
@@ -7927,6 +9056,42 @@ void ggml_compute_forward_unary(
|
|
|
7927
9056
|
}
|
|
7928
9057
|
}
|
|
7929
9058
|
|
|
9059
|
+
//ggml_compute_forward_glu
|
|
9060
|
+
|
|
9061
|
+
void ggml_compute_forward_glu(
|
|
9062
|
+
const ggml_compute_params * params,
|
|
9063
|
+
ggml_tensor * dst) {
|
|
9064
|
+
|
|
9065
|
+
const ggml_glu_op op = ggml_get_glu_op(dst);
|
|
9066
|
+
|
|
9067
|
+
switch (op) {
|
|
9068
|
+
case GGML_GLU_OP_REGLU:
|
|
9069
|
+
{
|
|
9070
|
+
ggml_compute_forward_reglu(params, dst);
|
|
9071
|
+
} break;
|
|
9072
|
+
case GGML_GLU_OP_GEGLU:
|
|
9073
|
+
{
|
|
9074
|
+
ggml_compute_forward_geglu(params, dst);
|
|
9075
|
+
} break;
|
|
9076
|
+
case GGML_GLU_OP_SWIGLU:
|
|
9077
|
+
{
|
|
9078
|
+
ggml_compute_forward_swiglu(params, dst);
|
|
9079
|
+
} break;
|
|
9080
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
9081
|
+
{
|
|
9082
|
+
ggml_compute_forward_geglu_erf(params, dst);
|
|
9083
|
+
} break;
|
|
9084
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
9085
|
+
{
|
|
9086
|
+
ggml_compute_forward_geglu_quick(params, dst);
|
|
9087
|
+
} break;
|
|
9088
|
+
default:
|
|
9089
|
+
{
|
|
9090
|
+
GGML_ABORT("fatal error");
|
|
9091
|
+
}
|
|
9092
|
+
}
|
|
9093
|
+
}
|
|
9094
|
+
|
|
7930
9095
|
// ggml_compute_forward_get_rel_pos
|
|
7931
9096
|
|
|
7932
9097
|
static void ggml_compute_forward_get_rel_pos_f16(
|