@novastera-oss/llamarn 0.2.9 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +0 -1
- package/cpp/llama.cpp/README.md +4 -5
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +17 -0
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.h +4 -0
- package/cpp/llama.cpp/convert_hf_to_gguf.py +745 -6
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
- package/cpp/llama.cpp/ggml/CMakeLists.txt +7 -2
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1203 -163
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +185 -79
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +64 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +35 -9
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +167 -39
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +254 -57
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +505 -40
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +60 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +711 -292
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
- package/cpp/llama.cpp/ggml/src/ggml.c +382 -61
- package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +209 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +73 -21
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
- package/cpp/llama.cpp/include/llama.h +0 -40
- package/cpp/llama.cpp/src/llama-arch.cpp +210 -3
- package/cpp/llama.cpp/src/llama-arch.h +18 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +27 -1
- package/cpp/llama.cpp/src/llama-batch.h +8 -1
- package/cpp/llama.cpp/src/llama-chat.cpp +15 -0
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-graph.cpp +119 -184
- package/cpp/llama.cpp/src/llama-graph.h +47 -60
- package/cpp/llama.cpp/src/llama-hparams.cpp +7 -1
- package/cpp/llama.cpp/src/llama-hparams.h +3 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +28 -18
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +4 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +214 -65
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +62 -24
- package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +9 -4
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +20 -10
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +3 -0
- package/cpp/llama.cpp/src/llama-model.cpp +2530 -685
- package/cpp/llama.cpp/src/llama-model.h +18 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +1 -0
- package/cpp/llama.cpp/src/llama-vocab.cpp +13 -2
- package/cpp/llama.cpp/src/llama-vocab.h +41 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +4 -0
- package/ios/include/llama.h +0 -40
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5055 -4886
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3766
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4890
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5091 -4922
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4897
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3794
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
#include "ggml-cpu.h"
|
|
4
4
|
#include "ggml-impl.h"
|
|
5
5
|
#include "binary-ops.h"
|
|
6
|
+
#include "ggml.h"
|
|
6
7
|
#include "unary-ops.h"
|
|
7
8
|
#include "vec.h"
|
|
8
9
|
|
|
@@ -3184,6 +3185,721 @@ void ggml_compute_forward_silu_back(
|
|
|
3184
3185
|
}
|
|
3185
3186
|
}
|
|
3186
3187
|
|
|
3188
|
+
// ggml_compute_forward_reglu
|
|
3189
|
+
|
|
3190
|
+
static void ggml_compute_forward_reglu_f32(
|
|
3191
|
+
const ggml_compute_params * params,
|
|
3192
|
+
ggml_tensor * dst) {
|
|
3193
|
+
|
|
3194
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3195
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
3196
|
+
char * src0_d = (char *) src0->data;
|
|
3197
|
+
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
|
3198
|
+
const size_t src0_o = src0->nb[1];
|
|
3199
|
+
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
|
3200
|
+
|
|
3201
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
3202
|
+
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
|
3203
|
+
|
|
3204
|
+
if (src1) {
|
|
3205
|
+
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
|
3206
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
3207
|
+
}
|
|
3208
|
+
|
|
3209
|
+
const int ith = params->ith;
|
|
3210
|
+
const int nth = params->nth;
|
|
3211
|
+
|
|
3212
|
+
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
|
3213
|
+
const int nr = ggml_nrows(src0);
|
|
3214
|
+
|
|
3215
|
+
GGML_ASSERT(dst->ne[0] == nc);
|
|
3216
|
+
GGML_ASSERT(ggml_nrows(dst) == nr);
|
|
3217
|
+
|
|
3218
|
+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
|
3219
|
+
|
|
3220
|
+
// rows per thread
|
|
3221
|
+
const int dr = (nr + nth - 1)/nth;
|
|
3222
|
+
|
|
3223
|
+
// row range for this thread
|
|
3224
|
+
const int ir0 = dr*ith;
|
|
3225
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
3226
|
+
|
|
3227
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
3228
|
+
float * src0_p = (float *) (src0_d + i1*src0_o);
|
|
3229
|
+
float * src1_p = (float *) (src1_d + i1*src1_o);
|
|
3230
|
+
|
|
3231
|
+
if (!src1) {
|
|
3232
|
+
src0_p += swapped ? nc : 0;
|
|
3233
|
+
src1_p += swapped ? 0 : nc;
|
|
3234
|
+
}
|
|
3235
|
+
|
|
3236
|
+
ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3237
|
+
|
|
3238
|
+
#ifndef NDEBUG
|
|
3239
|
+
for (int k = 0; k < nc; k++) {
|
|
3240
|
+
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
3241
|
+
GGML_UNUSED(x);
|
|
3242
|
+
assert(!isnan(x));
|
|
3243
|
+
assert(!isinf(x));
|
|
3244
|
+
}
|
|
3245
|
+
#endif
|
|
3246
|
+
}
|
|
3247
|
+
}
|
|
3248
|
+
|
|
3249
|
+
static void ggml_compute_forward_reglu_f16(
|
|
3250
|
+
const ggml_compute_params * params,
|
|
3251
|
+
ggml_tensor * dst) {
|
|
3252
|
+
|
|
3253
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3254
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
3255
|
+
char * src0_d = (char *) src0->data;
|
|
3256
|
+
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
|
3257
|
+
const size_t src0_o = src0->nb[1];
|
|
3258
|
+
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
|
3259
|
+
|
|
3260
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
3261
|
+
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
|
3262
|
+
|
|
3263
|
+
if (src1) {
|
|
3264
|
+
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
|
3265
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
3266
|
+
}
|
|
3267
|
+
|
|
3268
|
+
const int ith = params->ith;
|
|
3269
|
+
const int nth = params->nth;
|
|
3270
|
+
|
|
3271
|
+
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
|
3272
|
+
const int nr = ggml_nrows(src0);
|
|
3273
|
+
|
|
3274
|
+
GGML_ASSERT(dst->ne[0] == nc);
|
|
3275
|
+
GGML_ASSERT(ggml_nrows(dst) == nr);
|
|
3276
|
+
|
|
3277
|
+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
|
3278
|
+
|
|
3279
|
+
// rows per thread
|
|
3280
|
+
const int dr = (nr + nth - 1)/nth;
|
|
3281
|
+
|
|
3282
|
+
// row range for this thread
|
|
3283
|
+
const int ir0 = dr*ith;
|
|
3284
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
3285
|
+
|
|
3286
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
3287
|
+
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
|
|
3288
|
+
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
|
|
3289
|
+
|
|
3290
|
+
if (!src1) {
|
|
3291
|
+
src0_p += swapped ? nc : 0;
|
|
3292
|
+
src1_p += swapped ? 0 : nc;
|
|
3293
|
+
}
|
|
3294
|
+
|
|
3295
|
+
ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3296
|
+
|
|
3297
|
+
#ifndef NDEBUG
|
|
3298
|
+
for (int k = 0; k < nc; k++) {
|
|
3299
|
+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
3300
|
+
const float v = GGML_FP16_TO_FP32(x);
|
|
3301
|
+
GGML_UNUSED(v);
|
|
3302
|
+
assert(!isnan(v));
|
|
3303
|
+
assert(!isinf(v));
|
|
3304
|
+
}
|
|
3305
|
+
#endif
|
|
3306
|
+
}
|
|
3307
|
+
}
|
|
3308
|
+
|
|
3309
|
+
static void ggml_compute_forward_reglu(
|
|
3310
|
+
const ggml_compute_params * params,
|
|
3311
|
+
ggml_tensor * dst) {
|
|
3312
|
+
|
|
3313
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3314
|
+
|
|
3315
|
+
switch (src0->type) {
|
|
3316
|
+
case GGML_TYPE_F32:
|
|
3317
|
+
{
|
|
3318
|
+
ggml_compute_forward_reglu_f32(params, dst);
|
|
3319
|
+
} break;
|
|
3320
|
+
case GGML_TYPE_F16:
|
|
3321
|
+
{
|
|
3322
|
+
ggml_compute_forward_reglu_f16(params, dst);
|
|
3323
|
+
} break;
|
|
3324
|
+
default:
|
|
3325
|
+
{
|
|
3326
|
+
GGML_ABORT("fatal error");
|
|
3327
|
+
}
|
|
3328
|
+
}
|
|
3329
|
+
}
|
|
3330
|
+
|
|
3331
|
+
// ggml_compute_forward_geglu
|
|
3332
|
+
|
|
3333
|
+
static void ggml_compute_forward_geglu_f32(
|
|
3334
|
+
const ggml_compute_params * params,
|
|
3335
|
+
ggml_tensor * dst) {
|
|
3336
|
+
|
|
3337
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3338
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
3339
|
+
char * src0_d = (char *) src0->data;
|
|
3340
|
+
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
|
3341
|
+
const size_t src0_o = src0->nb[1];
|
|
3342
|
+
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
|
3343
|
+
|
|
3344
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
3345
|
+
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
|
3346
|
+
|
|
3347
|
+
if (src1) {
|
|
3348
|
+
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
|
3349
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
3350
|
+
}
|
|
3351
|
+
|
|
3352
|
+
const int ith = params->ith;
|
|
3353
|
+
const int nth = params->nth;
|
|
3354
|
+
|
|
3355
|
+
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
|
3356
|
+
const int nr = ggml_nrows(src0);
|
|
3357
|
+
|
|
3358
|
+
GGML_ASSERT(dst->ne[0] == nc);
|
|
3359
|
+
GGML_ASSERT(ggml_nrows(dst) == nr);
|
|
3360
|
+
|
|
3361
|
+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
|
3362
|
+
|
|
3363
|
+
// rows per thread
|
|
3364
|
+
const int dr = (nr + nth - 1)/nth;
|
|
3365
|
+
|
|
3366
|
+
// row range for this thread
|
|
3367
|
+
const int ir0 = dr*ith;
|
|
3368
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
3369
|
+
|
|
3370
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
3371
|
+
float * src0_p = (float *) (src0_d + i1*src0_o);
|
|
3372
|
+
float * src1_p = (float *) (src1_d + i1*src1_o);
|
|
3373
|
+
|
|
3374
|
+
if (!src1) {
|
|
3375
|
+
src0_p += swapped ? nc : 0;
|
|
3376
|
+
src1_p += swapped ? 0 : nc;
|
|
3377
|
+
}
|
|
3378
|
+
|
|
3379
|
+
ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3380
|
+
|
|
3381
|
+
#ifndef NDEBUG
|
|
3382
|
+
for (int k = 0; k < nc; k++) {
|
|
3383
|
+
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
3384
|
+
GGML_UNUSED(x);
|
|
3385
|
+
assert(!isnan(x));
|
|
3386
|
+
assert(!isinf(x));
|
|
3387
|
+
}
|
|
3388
|
+
#endif
|
|
3389
|
+
}
|
|
3390
|
+
}
|
|
3391
|
+
|
|
3392
|
+
static void ggml_compute_forward_geglu_f16(
|
|
3393
|
+
const ggml_compute_params * params,
|
|
3394
|
+
ggml_tensor * dst) {
|
|
3395
|
+
|
|
3396
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3397
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
3398
|
+
char * src0_d = (char *) src0->data;
|
|
3399
|
+
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
|
3400
|
+
const size_t src0_o = src0->nb[1];
|
|
3401
|
+
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
|
3402
|
+
|
|
3403
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
3404
|
+
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
|
3405
|
+
|
|
3406
|
+
if (src1) {
|
|
3407
|
+
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
|
3408
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
3409
|
+
}
|
|
3410
|
+
|
|
3411
|
+
const int ith = params->ith;
|
|
3412
|
+
const int nth = params->nth;
|
|
3413
|
+
|
|
3414
|
+
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
|
3415
|
+
const int nr = ggml_nrows(src0);
|
|
3416
|
+
|
|
3417
|
+
GGML_ASSERT(dst->ne[0] == nc);
|
|
3418
|
+
GGML_ASSERT(ggml_nrows(dst) == nr);
|
|
3419
|
+
|
|
3420
|
+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
|
3421
|
+
|
|
3422
|
+
// rows per thread
|
|
3423
|
+
const int dr = (nr + nth - 1)/nth;
|
|
3424
|
+
|
|
3425
|
+
// row range for this thread
|
|
3426
|
+
const int ir0 = dr*ith;
|
|
3427
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
3428
|
+
|
|
3429
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
3430
|
+
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
|
|
3431
|
+
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
|
|
3432
|
+
|
|
3433
|
+
if (!src1) {
|
|
3434
|
+
src0_p += swapped ? nc : 0;
|
|
3435
|
+
src1_p += swapped ? 0 : nc;
|
|
3436
|
+
}
|
|
3437
|
+
|
|
3438
|
+
ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
|
3439
|
+
|
|
3440
|
+
#ifndef NDEBUG
|
|
3441
|
+
for (int k = 0; k < nc; k++) {
|
|
3442
|
+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
|
3443
|
+
const float v = GGML_FP16_TO_FP32(x);
|
|
3444
|
+
GGML_UNUSED(v);
|
|
3445
|
+
assert(!isnan(v));
|
|
3446
|
+
assert(!isinf(v));
|
|
3447
|
+
}
|
|
3448
|
+
#endif
|
|
3449
|
+
}
|
|
3450
|
+
}
|
|
3451
|
+
|
|
3452
|
+
static void ggml_compute_forward_geglu(
|
|
3453
|
+
const ggml_compute_params * params,
|
|
3454
|
+
ggml_tensor * dst) {
|
|
3455
|
+
|
|
3456
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3457
|
+
|
|
3458
|
+
switch (src0->type) {
|
|
3459
|
+
case GGML_TYPE_F32:
|
|
3460
|
+
{
|
|
3461
|
+
ggml_compute_forward_geglu_f32(params, dst);
|
|
3462
|
+
} break;
|
|
3463
|
+
case GGML_TYPE_F16:
|
|
3464
|
+
{
|
|
3465
|
+
ggml_compute_forward_geglu_f16(params, dst);
|
|
3466
|
+
} break;
|
|
3467
|
+
default:
|
|
3468
|
+
{
|
|
3469
|
+
GGML_ABORT("fatal error");
|
|
3470
|
+
}
|
|
3471
|
+
}
|
|
3472
|
+
}
|
|
3473
|
+
|
|
3474
|
+
// ggml_compute_forward_swiglu
|
|
3475
|
+
|
|
3476
|
+
// Forward SwiGLU (f32): applies the fused gated activation via ggml_vec_swiglu_f32
// row by row. The gate/value rows come either from two separate tensors
// (src0, src1) or, when src1 == NULL, from the two halves of each src0 row.
static void ggml_compute_forward_swiglu_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1]; // optional; NULL => split mode (halves of src0)
    char * src0_d = (char *) src0->data;
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
    const size_t src0_o = src0->nb[1]; // row stride in bytes
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];

    GGML_ASSERT(ggml_is_contiguous_1(src0));
    GGML_ASSERT(ggml_is_contiguous_1(dst));

    if (src1) {
        GGML_ASSERT(ggml_is_contiguous_1(src1));
        GGML_ASSERT(src0->type == src1->type);
    }

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // output columns: a full row in two-input mode, half a row in split mode
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
    const int nr = ggml_nrows(src0);

    GGML_ASSERT(dst->ne[0] == nc);
    GGML_ASSERT(ggml_nrows(dst) == nr);

    // op param 1: in split mode, selects which half of the row is the gate
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    for (int i1 = ir0; i1 < ir1; i1++) {
        float * src0_p = (float *) (src0_d + i1*src0_o);
        float * src1_p = (float *) (src1_d + i1*src1_o);

        if (!src1) {
            // split mode: both operands live in the same src0 row
            src0_p += swapped ? nc : 0;
            src1_p += swapped ? 0 : nc;
        }

        ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);

#ifndef NDEBUG
        // debug builds: verify the output row contains no NaN/Inf
        for (int k = 0; k < nc; k++) {
            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
            GGML_UNUSED(x);
            assert(!isnan(x));
            assert(!isinf(x));
        }
#endif
    }
}
|
|
3534
|
+
|
|
3535
|
+
// Forward SwiGLU (f16): same layout and threading as the f32 variant, but
// operating on ggml_fp16_t rows via ggml_vec_swiglu_f16.
static void ggml_compute_forward_swiglu_f16(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1]; // optional; NULL => split mode (halves of src0)
    char * src0_d = (char *) src0->data;
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
    const size_t src0_o = src0->nb[1]; // row stride in bytes
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];

    GGML_ASSERT(ggml_is_contiguous_1(src0));
    GGML_ASSERT(ggml_is_contiguous_1(dst));

    if (src1) {
        GGML_ASSERT(ggml_is_contiguous_1(src1));
        GGML_ASSERT(src0->type == src1->type);
    }

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // output columns: a full row in two-input mode, half a row in split mode
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
    const int nr = ggml_nrows(src0);

    GGML_ASSERT(dst->ne[0] == nc);
    GGML_ASSERT(ggml_nrows(dst) == nr);

    // op param 1: in split mode, selects which half of the row is the gate
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    for (int i1 = ir0; i1 < ir1; i1++) {
        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);

        if (!src1) {
            // split mode: both operands live in the same src0 row
            src0_p += swapped ? nc : 0;
            src1_p += swapped ? 0 : nc;
        }

        ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);

#ifndef NDEBUG
        // debug builds: widen each output to f32 and verify no NaN/Inf
        for (int k = 0; k < nc; k++) {
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
            const float v = GGML_FP16_TO_FP32(x);
            GGML_UNUSED(v);
            assert(!isnan(v));
            assert(!isinf(v));
        }
#endif
    }
}
|
|
3594
|
+
|
|
3595
|
+
static void ggml_compute_forward_swiglu(
|
|
3596
|
+
const ggml_compute_params * params,
|
|
3597
|
+
ggml_tensor * dst) {
|
|
3598
|
+
|
|
3599
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3600
|
+
|
|
3601
|
+
switch (src0->type) {
|
|
3602
|
+
case GGML_TYPE_F32:
|
|
3603
|
+
{
|
|
3604
|
+
ggml_compute_forward_swiglu_f32(params, dst);
|
|
3605
|
+
} break;
|
|
3606
|
+
case GGML_TYPE_F16:
|
|
3607
|
+
{
|
|
3608
|
+
ggml_compute_forward_swiglu_f16(params, dst);
|
|
3609
|
+
} break;
|
|
3610
|
+
default:
|
|
3611
|
+
{
|
|
3612
|
+
GGML_ABORT("fatal error");
|
|
3613
|
+
}
|
|
3614
|
+
}
|
|
3615
|
+
}
|
|
3616
|
+
|
|
3617
|
+
// ggml_compute_forward_geglu_erf
|
|
3618
|
+
|
|
3619
|
+
// Forward GEGLU-erf (f32): gated activation computed row by row via
// ggml_vec_geglu_erf_f32. Same two-input vs split-row layout as the other
// glu kernels in this file.
static void ggml_compute_forward_geglu_erf_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1]; // optional; NULL => split mode (halves of src0)
    char * src0_d = (char *) src0->data;
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
    const size_t src0_o = src0->nb[1]; // row stride in bytes
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];

    GGML_ASSERT(ggml_is_contiguous_1(src0));
    GGML_ASSERT(ggml_is_contiguous_1(dst));

    if (src1) {
        GGML_ASSERT(ggml_is_contiguous_1(src1));
        GGML_ASSERT(src0->type == src1->type);
    }

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // output columns: a full row in two-input mode, half a row in split mode
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
    const int nr = ggml_nrows(src0);

    GGML_ASSERT(dst->ne[0] == nc);
    GGML_ASSERT(ggml_nrows(dst) == nr);

    // op param 1: in split mode, selects which half of the row is the gate
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    for (int i1 = ir0; i1 < ir1; i1++) {
        float * src0_p = (float *) (src0_d + i1*src0_o);
        float * src1_p = (float *) (src1_d + i1*src1_o);

        if (!src1) {
            // split mode: both operands live in the same src0 row
            src0_p += swapped ? nc : 0;
            src1_p += swapped ? 0 : nc;
        }

        ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);

#ifndef NDEBUG
        // debug builds: verify the output row contains no NaN/Inf
        for (int k = 0; k < nc; k++) {
            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
            GGML_UNUSED(x);
            assert(!isnan(x));
            assert(!isinf(x));
        }
#endif
    }
}
|
|
3677
|
+
|
|
3678
|
+
// Forward GEGLU-erf (f16): same layout and threading as the f32 variant, but
// operating on ggml_fp16_t rows via ggml_vec_geglu_erf_f16.
static void ggml_compute_forward_geglu_erf_f16(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1]; // optional; NULL => split mode (halves of src0)
    char * src0_d = (char *) src0->data;
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
    const size_t src0_o = src0->nb[1]; // row stride in bytes
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];

    GGML_ASSERT(ggml_is_contiguous_1(src0));
    GGML_ASSERT(ggml_is_contiguous_1(dst));

    if (src1) {
        GGML_ASSERT(ggml_is_contiguous_1(src1));
        GGML_ASSERT(src0->type == src1->type);
    }

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // output columns: a full row in two-input mode, half a row in split mode
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
    const int nr = ggml_nrows(src0);

    GGML_ASSERT(dst->ne[0] == nc);
    GGML_ASSERT(ggml_nrows(dst) == nr);

    // op param 1: in split mode, selects which half of the row is the gate
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    for (int i1 = ir0; i1 < ir1; i1++) {
        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);

        if (!src1) {
            // split mode: both operands live in the same src0 row
            src0_p += swapped ? nc : 0;
            src1_p += swapped ? 0 : nc;
        }

        ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);

#ifndef NDEBUG
        // debug builds: widen each output to f32 and verify no NaN/Inf
        for (int k = 0; k < nc; k++) {
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
            const float v = GGML_FP16_TO_FP32(x);
            GGML_UNUSED(v);
            assert(!isnan(v));
            assert(!isinf(v));
        }
#endif
    }
}
|
|
3737
|
+
|
|
3738
|
+
static void ggml_compute_forward_geglu_erf(
|
|
3739
|
+
const ggml_compute_params * params,
|
|
3740
|
+
ggml_tensor * dst) {
|
|
3741
|
+
|
|
3742
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3743
|
+
|
|
3744
|
+
switch (src0->type) {
|
|
3745
|
+
case GGML_TYPE_F32:
|
|
3746
|
+
{
|
|
3747
|
+
ggml_compute_forward_geglu_erf_f32(params, dst);
|
|
3748
|
+
} break;
|
|
3749
|
+
case GGML_TYPE_F16:
|
|
3750
|
+
{
|
|
3751
|
+
ggml_compute_forward_geglu_erf_f16(params, dst);
|
|
3752
|
+
} break;
|
|
3753
|
+
default:
|
|
3754
|
+
{
|
|
3755
|
+
GGML_ABORT("fatal error");
|
|
3756
|
+
}
|
|
3757
|
+
}
|
|
3758
|
+
}
|
|
3759
|
+
|
|
3760
|
+
// ggml_compute_forward_geglu_quick
|
|
3761
|
+
|
|
3762
|
+
// Forward GEGLU-quick (f32): gated activation computed row by row via
// ggml_vec_geglu_quick_f32. Same two-input vs split-row layout as the other
// glu kernels in this file.
static void ggml_compute_forward_geglu_quick_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1]; // optional; NULL => split mode (halves of src0)
    char * src0_d = (char *) src0->data;
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
    const size_t src0_o = src0->nb[1]; // row stride in bytes
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];

    GGML_ASSERT(ggml_is_contiguous_1(src0));
    GGML_ASSERT(ggml_is_contiguous_1(dst));

    if (src1) {
        GGML_ASSERT(ggml_is_contiguous_1(src1));
        GGML_ASSERT(src0->type == src1->type);
    }

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // output columns: a full row in two-input mode, half a row in split mode
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
    const int nr = ggml_nrows(src0);

    GGML_ASSERT(dst->ne[0] == nc);
    GGML_ASSERT(ggml_nrows(dst) == nr);

    // op param 1: in split mode, selects which half of the row is the gate
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    for (int i1 = ir0; i1 < ir1; i1++) {
        float * src0_p = (float *) (src0_d + i1*src0_o);
        float * src1_p = (float *) (src1_d + i1*src1_o);

        if (!src1) {
            // split mode: both operands live in the same src0 row
            src0_p += swapped ? nc : 0;
            src1_p += swapped ? 0 : nc;
        }

        ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);

#ifndef NDEBUG
        // debug builds: verify the output row contains no NaN/Inf
        for (int k = 0; k < nc; k++) {
            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
            GGML_UNUSED(x);
            assert(!isnan(x));
            assert(!isinf(x));
        }
#endif
    }
}
|
|
3820
|
+
|
|
3821
|
+
// Forward GEGLU-quick (f16): same layout and threading as the f32 variant, but
// operating on ggml_fp16_t rows via ggml_vec_geglu_quick_f16.
static void ggml_compute_forward_geglu_quick_f16(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1]; // optional; NULL => split mode (halves of src0)
    char * src0_d = (char *) src0->data;
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
    const size_t src0_o = src0->nb[1]; // row stride in bytes
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];

    GGML_ASSERT(ggml_is_contiguous_1(src0));
    GGML_ASSERT(ggml_is_contiguous_1(dst));

    if (src1) {
        GGML_ASSERT(ggml_is_contiguous_1(src1));
        GGML_ASSERT(src0->type == src1->type);
    }

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // output columns: a full row in two-input mode, half a row in split mode
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
    const int nr = ggml_nrows(src0);

    GGML_ASSERT(dst->ne[0] == nc);
    GGML_ASSERT(ggml_nrows(dst) == nr);

    // op param 1: in split mode, selects which half of the row is the gate
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    for (int i1 = ir0; i1 < ir1; i1++) {
        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);

        if (!src1) {
            // split mode: both operands live in the same src0 row
            src0_p += swapped ? nc : 0;
            src1_p += swapped ? 0 : nc;
        }

        ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);

#ifndef NDEBUG
        // debug builds: widen each output to f32 and verify no NaN/Inf
        for (int k = 0; k < nc; k++) {
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
            const float v = GGML_FP16_TO_FP32(x);
            GGML_UNUSED(v);
            assert(!isnan(v));
            assert(!isinf(v));
        }
#endif
    }
}
|
|
3880
|
+
|
|
3881
|
+
static void ggml_compute_forward_geglu_quick(
|
|
3882
|
+
const ggml_compute_params * params,
|
|
3883
|
+
ggml_tensor * dst) {
|
|
3884
|
+
|
|
3885
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
3886
|
+
|
|
3887
|
+
switch (src0->type) {
|
|
3888
|
+
case GGML_TYPE_F32:
|
|
3889
|
+
{
|
|
3890
|
+
ggml_compute_forward_geglu_quick_f32(params, dst);
|
|
3891
|
+
} break;
|
|
3892
|
+
case GGML_TYPE_F16:
|
|
3893
|
+
{
|
|
3894
|
+
ggml_compute_forward_geglu_quick_f16(params, dst);
|
|
3895
|
+
} break;
|
|
3896
|
+
default:
|
|
3897
|
+
{
|
|
3898
|
+
GGML_ABORT("fatal error");
|
|
3899
|
+
}
|
|
3900
|
+
}
|
|
3901
|
+
}
|
|
3902
|
+
|
|
3187
3903
|
// ggml_compute_forward_norm
|
|
3188
3904
|
|
|
3189
3905
|
static void ggml_compute_forward_norm_f32(
|
|
@@ -3927,9 +4643,11 @@ static void ggml_compute_forward_scale_f32(
|
|
|
3927
4643
|
GGML_ASSERT(ggml_is_contiguous(dst));
|
|
3928
4644
|
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
|
3929
4645
|
|
|
3930
|
-
// scale factor
|
|
3931
|
-
float
|
|
3932
|
-
|
|
4646
|
+
float s; // scale factor
|
|
4647
|
+
float b; // bias
|
|
4648
|
+
|
|
4649
|
+
memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
|
|
4650
|
+
memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
|
|
3933
4651
|
|
|
3934
4652
|
const int ith = params->ith;
|
|
3935
4653
|
const int nth = params->nth;
|
|
@@ -3948,12 +4666,22 @@ static void ggml_compute_forward_scale_f32(
|
|
|
3948
4666
|
|
|
3949
4667
|
const size_t nb1 = dst->nb[1];
|
|
3950
4668
|
|
|
3951
|
-
|
|
3952
|
-
|
|
3953
|
-
|
|
3954
|
-
|
|
4669
|
+
if (b == 0.0f) {
|
|
4670
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
4671
|
+
if (dst->data != src0->data) {
|
|
4672
|
+
// src0 is same shape as dst => same indices
|
|
4673
|
+
// TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
|
|
4674
|
+
memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
|
|
4675
|
+
}
|
|
4676
|
+
ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
|
|
4677
|
+
}
|
|
4678
|
+
} else {
|
|
4679
|
+
for (int i1 = ir0; i1 < ir1; i1++) {
|
|
4680
|
+
ggml_vec_mad1_f32(nc,
|
|
4681
|
+
(float *) ((char *) dst->data + i1*nb1),
|
|
4682
|
+
(float *) ((char *) src0->data + i1*nb1),
|
|
4683
|
+
s, b);
|
|
3955
4684
|
}
|
|
3956
|
-
ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
|
|
3957
4685
|
}
|
|
3958
4686
|
}
|
|
3959
4687
|
|
|
@@ -4802,14 +5530,17 @@ static void ggml_compute_forward_soft_max_f32(
|
|
|
4802
5530
|
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
|
4803
5531
|
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
|
4804
5532
|
|
|
4805
|
-
// TODO: handle transposed/permuted matrices
|
|
4806
|
-
|
|
4807
5533
|
const int ith = params->ith;
|
|
4808
5534
|
const int nth = params->nth;
|
|
4809
5535
|
|
|
4810
5536
|
GGML_TENSOR_UNARY_OP_LOCALS
|
|
4811
5537
|
|
|
4812
|
-
|
|
5538
|
+
const int64_t nb11 = src1 ? src1->nb[1] : 1;
|
|
5539
|
+
const int64_t nb12 = src1 ? src1->nb[2] : 1;
|
|
5540
|
+
const int64_t nb13 = src1 ? src1->nb[3] : 1;
|
|
5541
|
+
|
|
5542
|
+
const int64_t ne12 = src1 ? src1->ne[2] : 1;
|
|
5543
|
+
const int64_t ne13 = src1 ? src1->ne[3] : 1;
|
|
4813
5544
|
|
|
4814
5545
|
// TODO: is this supposed to be ceil instead of floor?
|
|
4815
5546
|
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
|
|
@@ -4819,68 +5550,66 @@ static void ggml_compute_forward_soft_max_f32(
|
|
|
4819
5550
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
4820
5551
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
4821
5552
|
|
|
4822
|
-
|
|
4823
|
-
const int nr = ggml_nrows(src0);
|
|
4824
|
-
|
|
4825
|
-
// rows per thread
|
|
4826
|
-
const int dr = (nr + nth - 1)/nth;
|
|
4827
|
-
|
|
4828
|
-
// row range for this thread
|
|
4829
|
-
const int ir0 = dr*ith;
|
|
4830
|
-
const int ir1 = MIN(ir0 + dr, nr);
|
|
4831
|
-
|
|
4832
|
-
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
|
|
5553
|
+
float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
|
4833
5554
|
|
|
4834
5555
|
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
|
4835
5556
|
|
|
4836
|
-
for (
|
|
4837
|
-
|
|
4838
|
-
|
|
4839
|
-
|
|
4840
|
-
|
|
4841
|
-
|
|
4842
|
-
|
|
4843
|
-
|
|
4844
|
-
|
|
4845
|
-
|
|
4846
|
-
|
|
4847
|
-
|
|
4848
|
-
|
|
4849
|
-
|
|
4850
|
-
|
|
4851
|
-
|
|
4852
|
-
|
|
4853
|
-
|
|
4854
|
-
|
|
4855
|
-
|
|
4856
|
-
|
|
4857
|
-
|
|
5557
|
+
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
5558
|
+
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
5559
|
+
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
|
5560
|
+
const int64_t i11 = i01;
|
|
5561
|
+
const int64_t i12 = i02%ne12;
|
|
5562
|
+
const int64_t i13 = i03%ne13;
|
|
5563
|
+
|
|
5564
|
+
// ALiBi
|
|
5565
|
+
const uint32_t h = i02; // head
|
|
5566
|
+
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
|
5567
|
+
|
|
5568
|
+
float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
5569
|
+
float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
|
5570
|
+
|
|
5571
|
+
// broadcast the mask across rows
|
|
5572
|
+
ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
|
|
5573
|
+
float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
|
|
5574
|
+
|
|
5575
|
+
ggml_vec_cpy_f32 (ne00, wp, sp);
|
|
5576
|
+
ggml_vec_scale_f32(ne00, wp, scale);
|
|
5577
|
+
if (mp_f32) {
|
|
5578
|
+
if (use_f16) {
|
|
5579
|
+
for (int i = 0; i < ne00; ++i) {
|
|
5580
|
+
wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
|
|
5581
|
+
}
|
|
5582
|
+
} else {
|
|
5583
|
+
for (int i = 0; i < ne00; ++i) {
|
|
5584
|
+
wp[i] += slope*mp_f32[i];
|
|
5585
|
+
}
|
|
5586
|
+
}
|
|
4858
5587
|
}
|
|
4859
|
-
}
|
|
4860
|
-
}
|
|
4861
5588
|
|
|
4862
5589
|
#ifndef NDEBUG
|
|
4863
|
-
|
|
4864
|
-
|
|
4865
|
-
|
|
4866
|
-
|
|
5590
|
+
for (int i = 0; i < ne00; ++i) {
|
|
5591
|
+
//printf("p[%d] = %f\n", i, p[i]);
|
|
5592
|
+
assert(!isnan(wp[i]));
|
|
5593
|
+
}
|
|
4867
5594
|
#endif
|
|
4868
5595
|
|
|
4869
|
-
|
|
4870
|
-
|
|
5596
|
+
float max = -INFINITY;
|
|
5597
|
+
ggml_vec_max_f32(ne00, &max, wp);
|
|
4871
5598
|
|
|
4872
|
-
|
|
4873
|
-
|
|
5599
|
+
ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
|
|
5600
|
+
assert(sum > 0.0);
|
|
4874
5601
|
|
|
4875
|
-
|
|
4876
|
-
|
|
5602
|
+
sum = 1.0/sum;
|
|
5603
|
+
ggml_vec_scale_f32(ne00, dp, sum);
|
|
4877
5604
|
|
|
4878
5605
|
#ifndef NDEBUG
|
|
4879
|
-
|
|
4880
|
-
|
|
4881
|
-
|
|
4882
|
-
|
|
5606
|
+
for (int i = 0; i < ne00; ++i) {
|
|
5607
|
+
assert(!isnan(dp[i]));
|
|
5608
|
+
assert(!isinf(dp[i]));
|
|
5609
|
+
}
|
|
4883
5610
|
#endif
|
|
5611
|
+
}
|
|
5612
|
+
}
|
|
4884
5613
|
}
|
|
4885
5614
|
}
|
|
4886
5615
|
|
|
@@ -6116,6 +6845,186 @@ void ggml_compute_forward_im2col_back_f32(
|
|
|
6116
6845
|
}
|
|
6117
6846
|
}
|
|
6118
6847
|
|
|
6848
|
+
// Invoke ggml_compute_forward_mul_mat on raw buffers by wrapping them in
// temporary value-initialized tensors on the stack:
//   a: [m, k] in `type` (becomes src1), b: [n, k] in `type` (becomes src0),
//   c: [m, n] f32 output. All strides are set for dense, contiguous 2D data.
// NOTE(review): the temporaries only fill the fields mul_mat reads (type,
// ne, nb, data, src); they are not registered with any ggml context.
static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
                              void * a, void * b, float * c) {
    const ggml_type_traits * traits = ggml_get_type_traits(type);

    // src1 wraps `a` as an m x k matrix of `type`
    struct ggml_tensor src1 = {};
    src1.type = type;
    src1.ne[0] = k;
    src1.ne[1] = m;
    src1.ne[2] = 1;
    src1.ne[3] = 1;
    src1.nb[0] = traits->type_size;
    src1.nb[1] = k * traits->type_size;
    src1.nb[2] = src1.nb[1]; // singleton trailing dims reuse the row stride
    src1.nb[3] = src1.nb[2];
    src1.data = a;

    // src0 wraps `b` as an n x k matrix of `type`
    struct ggml_tensor src0 = {};
    src0.type = type;
    src0.ne[0] = k;
    src0.ne[1] = n;
    src0.ne[2] = 1;
    src0.ne[3] = 1;
    src0.nb[0] = traits->type_size;
    src0.nb[1] = k * traits->type_size;
    src0.nb[2] = src0.nb[1];
    src0.nb[3] = src0.nb[2];
    src0.data = b;

    // dst wraps `c` as an m x n f32 matrix
    struct ggml_tensor dst = {};
    dst.ne[0] = n;
    dst.ne[1] = m;
    dst.ne[2] = 1;
    dst.ne[3] = 1;
    dst.nb[0] = sizeof(float);
    dst.nb[1] = n * sizeof(float);
    dst.nb[2] = dst.nb[1];
    dst.nb[3] = dst.nb[2];
    dst.data = c;
    dst.src[0] = &src0;
    dst.src[1] = &src1;

    ggml_compute_forward_mul_mat(params, &dst);
}
|
|
6890
|
+
|
|
6891
|
+
// ggml_compute_forward_conv_2d
|
|
6892
|
+
|
|
6893
|
+
// 2D convolution via im2col + GEMM, processed in batches of output patches
// sized to fit the per-op work buffer (params->wdata/wsize).
// For each batch: (1) each thread unrolls its share of patches into im2col
// rows in the work buffer, (2) one shared GEMM multiplies the patches by the
// kernel, (3) each thread scatters its share of GEMM rows into dst's layout.
// Barriers separate the phases because all threads read the shared buffer.
static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
                                              const ggml_tensor * kernel, // [KW, KH, IC, OC]
                                              const ggml_tensor * src,    // [W, H, C, N]
                                              ggml_tensor * dst,          // [OW, OH, OC, N]
                                              ggml_type kernel_type) {

    GGML_ASSERT(ggml_is_contiguous(kernel));
    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
    GGML_ASSERT(kernel->type == kernel_type);

    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);

    // convolution hyper-parameters packed into dst->op_params
    const int32_t stride_x = dst->op_params[0];
    const int32_t stride_y = dst->op_params[1];
    const int32_t pad_x = dst->op_params[2];
    const int32_t pad_y = dst->op_params[3];
    const int32_t dilation_x = dst->op_params[4];
    const int32_t dilation_y = dst->op_params[5];

    const int64_t c_in = src->ne[2];
    const int64_t c_out = kernel->ne[3];
    GGML_ASSERT(c_in == kernel->ne[2]);

    const int64_t src_w = src->ne[0];
    const int64_t src_h = src->ne[1];
    const int64_t knl_w = kernel->ne[0];
    const int64_t knl_h = kernel->ne[1];
    const int64_t dst_w = dst->ne[0];
    const int64_t dst_h = dst->ne[1];

    const float * src_data = (float *) src->data;
    void * knl_data = kernel->data;
    float * dst_data = (float *) dst->data;

    const int64_t knl_n = knl_w * knl_h * c_in;        // length of one im2col row
    const int64_t patch_total = dst->ne[3] * dst_w * dst_h; // one patch per output pixel per image

    // work buffer holds im2col rows plus the f32 GEMM output for one batch
    const int64_t space_per_patch = knl_n * traits->type_size + c_out * sizeof(float);
    const int64_t batch_size = params->wsize / space_per_patch;
    // round down to a multiple of 8 when possible (keeps GEMM tiles regular)
    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
    const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;

    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);

    void * tmp = params->wdata;

    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {

        const int64_t patch_start_batch = batch_i * patches_per_batch;
        const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch,
                                                 patch_total);
        const int64_t patch_n = patch_end_batch - patch_start_batch;

        // split this batch's patches across threads
        const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth;
        const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
        const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);

        //im2col for a patch
        for (int64_t p = patch_start; p < patch_end; ++p) {
            // NOTE(review): this inner batch_n (image index) shadows the outer
            // batch_n (number of batches) — intentional upstream naming, but
            // easy to misread.
            const int64_t batch_n = p / (dst_w * dst_h);
            const int64_t src_x = (p / dst_w) % dst_h; // output row index
            const int64_t src_y = p % dst_w;           // output column index

            const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]);
            char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;

            // gather the receptive field for this output pixel, zero-padding
            // out-of-bounds taps
            for (int64_t ic = 0; ic < c_in; ++ic) {
                for (int64_t ky = 0; ky < knl_h; ++ky) {
                    for (int64_t kx = 0; kx < knl_w; ++kx) {
                        const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
                        const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;

                        int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;

                        float src_val;
                        if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
                            src_val = 0.0f;
                        } else {
                            const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
                            src_val = *src_ptr;
                        }

                        // store in the kernel's type so the GEMM operands match
                        char * element_ptr = dst_row + dst_idx * traits->type_size;
                        if (kernel_type == GGML_TYPE_F32) {
                            *(float *) element_ptr = src_val;
                        } else if (kernel_type == GGML_TYPE_F16) {
                            *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
                        }
                    }
                }
            }
        } // patches handled by this thread

        // all im2col rows must be written before the shared GEMM reads them
        ggml_barrier(params->threadpool);

        // GEMM output region follows the im2col rows in the work buffer
        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);

        GGML_ASSERT(gemm_output + patch_n * c_out <= (float*)tmp + params->wsize);

        // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
        ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);

        // GEMM must finish before threads scatter its rows into dst
        ggml_barrier(params->threadpool);


        //permute back [OC, N, OH, OW] to [N, OC, OH, OW]
        const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
        const int64_t permute_start = params->ith * permute_per_thread;
        const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);

        for (int64_t i = permute_start; i < permute_end; ++i) {
            const int64_t p = patch_start_batch + i;
            const int64_t batch_n = p / (dst_w * dst_h); // image index (shadows outer batch_n)
            const int64_t dst_y = (p / dst_w) % dst_h;
            const int64_t dst_x = p % dst_w;

            for (int64_t oc = 0; oc < c_out; ++oc) {
                const float value = gemm_output[i * c_out + oc];
                float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
                *dst_ptr = value;
            }
        }
    }
}
|
|
7017
|
+
|
|
7018
|
+
void ggml_compute_forward_conv_2d(
|
|
7019
|
+
const ggml_compute_params * params,
|
|
7020
|
+
ggml_tensor * dst) {
|
|
7021
|
+
|
|
7022
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
7023
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
7024
|
+
|
|
7025
|
+
ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
|
|
7026
|
+
}
|
|
7027
|
+
|
|
6119
7028
|
// ggml_compute_forward_conv_transpose_2d
|
|
6120
7029
|
|
|
6121
7030
|
void ggml_compute_forward_conv_transpose_2d(
|
|
@@ -6666,12 +7575,13 @@ static void ggml_compute_forward_upscale_f32(
|
|
|
6666
7575
|
|
|
6667
7576
|
GGML_TENSOR_UNARY_OP_LOCALS
|
|
6668
7577
|
|
|
6669
|
-
|
|
6670
|
-
|
|
6671
|
-
|
|
6672
|
-
|
|
7578
|
+
float sf0 = (float)ne0/src0->ne[0];
|
|
7579
|
+
float sf1 = (float)ne1/src0->ne[1];
|
|
7580
|
+
float sf2 = (float)ne2/src0->ne[2];
|
|
7581
|
+
float sf3 = (float)ne3/src0->ne[3];
|
|
6673
7582
|
|
|
6674
|
-
const
|
|
7583
|
+
const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
|
|
7584
|
+
const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
|
|
6675
7585
|
|
|
6676
7586
|
if (mode == GGML_SCALE_MODE_NEAREST) {
|
|
6677
7587
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
@@ -6692,8 +7602,12 @@ static void ggml_compute_forward_upscale_f32(
|
|
|
6692
7602
|
}
|
|
6693
7603
|
}
|
|
6694
7604
|
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
|
|
6695
|
-
|
|
6696
|
-
|
|
7605
|
+
float pixel_offset = 0.5f;
|
|
7606
|
+
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
|
7607
|
+
pixel_offset = 0.0f;
|
|
7608
|
+
sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
|
|
7609
|
+
sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
|
|
7610
|
+
}
|
|
6697
7611
|
|
|
6698
7612
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
6699
7613
|
const int64_t i03 = i3 / sf3;
|
|
@@ -7151,7 +8065,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
7151
8065
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
7152
8066
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
7153
8067
|
|
|
7154
|
-
ggml_type
|
|
8068
|
+
ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
|
|
7155
8069
|
ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
|
|
7156
8070
|
ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
|
|
7157
8071
|
ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
|
|
@@ -7183,7 +8097,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
7183
8097
|
memset(VKQ32, 0, DV*sizeof(float));
|
|
7184
8098
|
}
|
|
7185
8099
|
|
|
7186
|
-
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
|
|
8100
|
+
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
|
|
7187
8101
|
|
|
7188
8102
|
// k indices
|
|
7189
8103
|
const int ik3 = iq3 / rk3;
|
|
@@ -7721,120 +8635,210 @@ void ggml_compute_forward_ssm_conv(
|
|
|
7721
8635
|
static void ggml_compute_forward_ssm_scan_f32(
|
|
7722
8636
|
const ggml_compute_params * params,
|
|
7723
8637
|
ggml_tensor * dst) {
|
|
7724
|
-
const ggml_tensor * src0 = dst->src[0]; // s
|
|
7725
|
-
const ggml_tensor * src1 = dst->src[1]; // x
|
|
7726
|
-
const ggml_tensor * src2 = dst->src[2]; // dt
|
|
7727
|
-
const ggml_tensor * src3 = dst->src[3]; // A
|
|
7728
|
-
const ggml_tensor * src4 = dst->src[4]; // B
|
|
7729
|
-
const ggml_tensor * src5 = dst->src[5]; // C
|
|
8638
|
+
const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
|
|
8639
|
+
const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
|
|
8640
|
+
const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
|
|
8641
|
+
const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
|
|
8642
|
+
const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
|
|
8643
|
+
const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
|
|
8644
|
+
const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
|
|
7730
8645
|
|
|
7731
8646
|
const int ith = params->ith;
|
|
7732
8647
|
const int nth = params->nth;
|
|
7733
8648
|
|
|
7734
|
-
const int64_t nc
|
|
7735
|
-
const int64_t nr
|
|
7736
|
-
const int64_t
|
|
7737
|
-
const int64_t
|
|
8649
|
+
const int64_t nc = src0->ne[0]; // d_state
|
|
8650
|
+
const int64_t nr = src0->ne[1]; // dim
|
|
8651
|
+
const int64_t nh = src1->ne[1]; // n_head
|
|
8652
|
+
const int64_t ng = src4->ne[1];
|
|
8653
|
+
const int64_t nt = src1->ne[2]; // number of tokens per sequence
|
|
8654
|
+
const int64_t ns = src1->ne[3]; // number of sequences in the batch
|
|
8655
|
+
|
|
8656
|
+
// can't use ggml_nbytes because src1 is not necessarily contiguous
|
|
8657
|
+
const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
|
|
7738
8658
|
|
|
7739
|
-
GGML_ASSERT(ggml_nelements(src1) +
|
|
8659
|
+
GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
|
|
7740
8660
|
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
|
7741
8661
|
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
|
7742
8662
|
GGML_ASSERT(src2->nb[0] == sizeof(float));
|
|
7743
8663
|
GGML_ASSERT(src3->nb[0] == sizeof(float));
|
|
7744
8664
|
GGML_ASSERT(src4->nb[0] == sizeof(float));
|
|
7745
8665
|
GGML_ASSERT(src5->nb[0] == sizeof(float));
|
|
7746
|
-
|
|
7747
|
-
|
|
7748
|
-
|
|
7749
|
-
GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
|
|
7750
|
-
// required to get correct offset for state destination (i.e. src1->nb[3])
|
|
7751
|
-
GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
|
|
8666
|
+
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
|
|
8667
|
+
// allows optimizing the modulo since n_group should be a power of 2
|
|
8668
|
+
GGML_ASSERT((ng & -ng) == ng);
|
|
7752
8669
|
|
|
7753
|
-
//
|
|
7754
|
-
const int
|
|
8670
|
+
// heads per thread
|
|
8671
|
+
const int dh = (nh + nth - 1)/nth;
|
|
7755
8672
|
|
|
7756
|
-
//
|
|
7757
|
-
const int
|
|
7758
|
-
const int
|
|
7759
|
-
|
|
8673
|
+
// head range for this thread
|
|
8674
|
+
const int ih0 = dh*ith;
|
|
8675
|
+
const int ih1 = MIN(ih0 + dh, nh);
|
|
8676
|
+
|
|
8677
|
+
const int32_t * ids = (const int32_t *) src6->data;
|
|
7760
8678
|
|
|
7761
|
-
|
|
7762
|
-
|
|
7763
|
-
|
|
7764
|
-
|
|
7765
|
-
|
|
7766
|
-
|
|
7767
|
-
|
|
7768
|
-
|
|
7769
|
-
|
|
7770
|
-
|
|
7771
|
-
|
|
7772
|
-
|
|
7773
|
-
|
|
7774
|
-
|
|
7775
|
-
|
|
7776
|
-
//
|
|
7777
|
-
for (int
|
|
7778
|
-
|
|
7779
|
-
float
|
|
7780
|
-
|
|
7781
|
-
|
|
7782
|
-
|
|
7783
|
-
|
|
7784
|
-
|
|
7785
|
-
|
|
7786
|
-
|
|
7787
|
-
|
|
7788
|
-
|
|
7789
|
-
|
|
7790
|
-
|
|
7791
|
-
|
|
7792
|
-
|
|
7793
|
-
|
|
7794
|
-
|
|
7795
|
-
|
|
7796
|
-
|
|
7797
|
-
|
|
8679
|
+
for (int i3 = 0; i3 < ns; ++i3) {
|
|
8680
|
+
const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
|
|
8681
|
+
float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
|
|
8682
|
+
|
|
8683
|
+
for (int i2 = 0; i2 < nt; ++i2) {
|
|
8684
|
+
const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
|
|
8685
|
+
const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
|
|
8686
|
+
const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
|
|
8687
|
+
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
|
|
8688
|
+
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
|
|
8689
|
+
float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
|
|
8690
|
+
|
|
8691
|
+
if (src3->ne[0] == 1) {
|
|
8692
|
+
// Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
|
|
8693
|
+
|
|
8694
|
+
// n_head
|
|
8695
|
+
for (int h = ih0; h < ih1; ++h) {
|
|
8696
|
+
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
|
8697
|
+
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
|
|
8698
|
+
const float dA = expf(dt_soft_plus * A[h]);
|
|
8699
|
+
|
|
8700
|
+
// dim
|
|
8701
|
+
for (int i1 = 0; i1 < nr; ++i1) {
|
|
8702
|
+
const int ii = i1 + h*nr;
|
|
8703
|
+
const float x_dt = x[ii] * dt_soft_plus;
|
|
8704
|
+
float sumf = 0.0f;
|
|
8705
|
+
#if defined(GGML_SIMD)
|
|
8706
|
+
#if defined(__ARM_FEATURE_SVE)
|
|
8707
|
+
const int ggml_f32_epr = svcntw();
|
|
8708
|
+
const int ggml_f32_step = 1 * ggml_f32_epr;
|
|
8709
|
+
|
|
8710
|
+
const int np = (nc & ~(ggml_f32_step - 1));
|
|
8711
|
+
|
|
8712
|
+
GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
|
|
8713
|
+
|
|
8714
|
+
GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
|
|
8715
|
+
GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
|
|
8716
|
+
|
|
8717
|
+
for (int i = 0; i < np; i += ggml_f32_step) {
|
|
8718
|
+
// TODO: maybe unroll more?
|
|
8719
|
+
for (int j = 0; j < 1; j++) {
|
|
8720
|
+
GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
|
|
8721
|
+
GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
|
|
8722
|
+
GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
|
|
8723
|
+
|
|
8724
|
+
t0 = GGML_F32_VEC_MUL(t0, adA);
|
|
8725
|
+
t1 = GGML_F32_VEC_MUL(t1, axdt);
|
|
8726
|
+
|
|
8727
|
+
t0 = GGML_F32_VEC_ADD(t0, t1);
|
|
8728
|
+
|
|
8729
|
+
sum = GGML_F32_VEC_FMA(sum, t0, t2);
|
|
8730
|
+
|
|
8731
|
+
GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
|
|
8732
|
+
}
|
|
8733
|
+
}
|
|
8734
|
+
|
|
8735
|
+
sumf = GGML_F32xt_REDUCE_ONE(sum);
|
|
8736
|
+
#else
|
|
8737
|
+
const int np = (nc & ~(GGML_F32_STEP - 1));
|
|
8738
|
+
|
|
8739
|
+
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
|
|
8740
|
+
|
|
8741
|
+
GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
|
|
8742
|
+
GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
|
|
8743
|
+
|
|
8744
|
+
GGML_F32_VEC ax[GGML_F32_ARR];
|
|
8745
|
+
GGML_F32_VEC ay[GGML_F32_ARR];
|
|
8746
|
+
GGML_F32_VEC az[GGML_F32_ARR];
|
|
8747
|
+
|
|
8748
|
+
for (int i = 0; i < np; i += GGML_F32_STEP) {
|
|
8749
|
+
for (int j = 0; j < GGML_F32_ARR; j++) {
|
|
8750
|
+
ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
|
|
8751
|
+
ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
|
|
8752
|
+
az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
|
|
8753
|
+
|
|
8754
|
+
ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
|
|
8755
|
+
ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
|
|
8756
|
+
|
|
8757
|
+
ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
|
|
8758
|
+
|
|
8759
|
+
sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
|
|
8760
|
+
|
|
8761
|
+
GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
|
|
8762
|
+
}
|
|
8763
|
+
}
|
|
8764
|
+
|
|
8765
|
+
// reduce sum0..sum3 to sum0
|
|
8766
|
+
GGML_F32_VEC_REDUCE(sumf, sum);
|
|
8767
|
+
#endif
|
|
8768
|
+
#else
|
|
8769
|
+
const int np = 0;
|
|
8770
|
+
#endif
|
|
8771
|
+
// d_state
|
|
8772
|
+
for (int i0 = np; i0 < nc; ++i0) {
|
|
8773
|
+
const int i = i0 + ii*nc;
|
|
8774
|
+
const int ig = i0 + (h & (ng - 1))*nc;
|
|
8775
|
+
// state = prev_state * dA + dB * x
|
|
8776
|
+
const float state = (s0[i] * dA) + (B[ig] * x_dt);
|
|
8777
|
+
// y = rowwise_dotprod(state, C)
|
|
8778
|
+
sumf += state * C[ig];
|
|
8779
|
+
s[i] = state;
|
|
8780
|
+
}
|
|
8781
|
+
y[ii] = sumf;
|
|
7798
8782
|
}
|
|
7799
|
-
y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
|
|
7800
8783
|
}
|
|
7801
|
-
}
|
|
7802
|
-
|
|
7803
|
-
|
|
7804
|
-
|
|
7805
|
-
|
|
7806
|
-
|
|
7807
|
-
|
|
7808
|
-
|
|
7809
|
-
|
|
7810
|
-
|
|
7811
|
-
|
|
7812
|
-
|
|
7813
|
-
|
|
7814
|
-
|
|
7815
|
-
|
|
7816
|
-
|
|
7817
|
-
|
|
7818
|
-
|
|
7819
|
-
|
|
7820
|
-
|
|
7821
|
-
|
|
7822
|
-
|
|
7823
|
-
|
|
7824
|
-
|
|
7825
|
-
|
|
7826
|
-
|
|
7827
|
-
|
|
7828
|
-
|
|
7829
|
-
|
|
7830
|
-
|
|
7831
|
-
|
|
8784
|
+
} else {
|
|
8785
|
+
// Mamba-1 has an element-wise decay factor for the states
|
|
8786
|
+
|
|
8787
|
+
// n_head
|
|
8788
|
+
for (int h = ih0; h < ih1; ++h) {
|
|
8789
|
+
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
|
8790
|
+
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
|
|
8791
|
+
|
|
8792
|
+
// dim
|
|
8793
|
+
for (int i1 = 0; i1 < nr; ++i1) {
|
|
8794
|
+
const int ii = i1 + h*nr;
|
|
8795
|
+
const float x_dt = x[ii] * dt_soft_plus;
|
|
8796
|
+
#if defined(__ARM_FEATURE_SVE)
|
|
8797
|
+
svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
|
|
8798
|
+
svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
|
|
8799
|
+
svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
|
|
8800
|
+
|
|
8801
|
+
// d_state
|
|
8802
|
+
// TODO: what happens when (d_state % svcntw()) != 0?
|
|
8803
|
+
for (int64_t k = 0; k < nc; k += svcntw()) {
|
|
8804
|
+
svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
|
|
8805
|
+
svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
|
|
8806
|
+
svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
|
|
8807
|
+
svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
|
|
8808
|
+
|
|
8809
|
+
svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
|
|
8810
|
+
t1 = exp_ps_sve(svptrue_b32(), t1);
|
|
8811
|
+
svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
|
|
8812
|
+
|
|
8813
|
+
vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
|
|
8814
|
+
r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
|
|
8815
|
+
|
|
8816
|
+
GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
|
|
8817
|
+
}
|
|
8818
|
+
y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
|
|
8819
|
+
#else
|
|
8820
|
+
float sumf = 0.0f;
|
|
8821
|
+
// NOTE: can't really use GGML_SIMD here because d_state is usually 16
|
|
8822
|
+
// and also because expf is used within the loop.
|
|
8823
|
+
// d_state
|
|
8824
|
+
for (int i0 = 0; i0 < nc; ++i0) {
|
|
8825
|
+
const int i = i0 + ii*nc;
|
|
8826
|
+
const int ig = i0 + (h & (ng - 1))*nc;
|
|
8827
|
+
// state = prev_state * dA + dB * x
|
|
8828
|
+
const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
|
|
8829
|
+
// y = rowwise_dotprod(state, C)
|
|
8830
|
+
sumf += state * C[ig];
|
|
8831
|
+
s[i] = state;
|
|
8832
|
+
}
|
|
8833
|
+
y[ii] = sumf;
|
|
8834
|
+
#endif
|
|
7832
8835
|
}
|
|
7833
|
-
y[i1] = sumf;
|
|
7834
8836
|
}
|
|
7835
8837
|
}
|
|
8838
|
+
// use the output as the source when it's not the first token-wise iteration
|
|
8839
|
+
s0 = s;
|
|
7836
8840
|
}
|
|
7837
|
-
|
|
8841
|
+
}
|
|
7838
8842
|
}
|
|
7839
8843
|
|
|
7840
8844
|
void ggml_compute_forward_ssm_scan(
|
|
@@ -8052,6 +9056,42 @@ void ggml_compute_forward_unary(
|
|
|
8052
9056
|
}
|
|
8053
9057
|
}
|
|
8054
9058
|
|
|
9059
|
+
//ggml_compute_forward_glu
|
|
9060
|
+
|
|
9061
|
+
void ggml_compute_forward_glu(
|
|
9062
|
+
const ggml_compute_params * params,
|
|
9063
|
+
ggml_tensor * dst) {
|
|
9064
|
+
|
|
9065
|
+
const ggml_glu_op op = ggml_get_glu_op(dst);
|
|
9066
|
+
|
|
9067
|
+
switch (op) {
|
|
9068
|
+
case GGML_GLU_OP_REGLU:
|
|
9069
|
+
{
|
|
9070
|
+
ggml_compute_forward_reglu(params, dst);
|
|
9071
|
+
} break;
|
|
9072
|
+
case GGML_GLU_OP_GEGLU:
|
|
9073
|
+
{
|
|
9074
|
+
ggml_compute_forward_geglu(params, dst);
|
|
9075
|
+
} break;
|
|
9076
|
+
case GGML_GLU_OP_SWIGLU:
|
|
9077
|
+
{
|
|
9078
|
+
ggml_compute_forward_swiglu(params, dst);
|
|
9079
|
+
} break;
|
|
9080
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
9081
|
+
{
|
|
9082
|
+
ggml_compute_forward_geglu_erf(params, dst);
|
|
9083
|
+
} break;
|
|
9084
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
9085
|
+
{
|
|
9086
|
+
ggml_compute_forward_geglu_quick(params, dst);
|
|
9087
|
+
} break;
|
|
9088
|
+
default:
|
|
9089
|
+
{
|
|
9090
|
+
GGML_ABORT("fatal error");
|
|
9091
|
+
}
|
|
9092
|
+
}
|
|
9093
|
+
}
|
|
9094
|
+
|
|
8055
9095
|
// ggml_compute_forward_get_rel_pos
|
|
8056
9096
|
|
|
8057
9097
|
static void ggml_compute_forward_get_rel_pos_f16(
|