@novastera-oss/llamarn 0.2.9 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +0 -1
- package/cpp/llama.cpp/README.md +4 -5
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +17 -0
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.h +4 -0
- package/cpp/llama.cpp/convert_hf_to_gguf.py +745 -6
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
- package/cpp/llama.cpp/ggml/CMakeLists.txt +7 -2
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1203 -163
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +185 -79
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +64 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +35 -9
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +167 -39
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +254 -57
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +505 -40
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +60 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +711 -292
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
- package/cpp/llama.cpp/ggml/src/ggml.c +382 -61
- package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +209 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +73 -21
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
- package/cpp/llama.cpp/include/llama.h +0 -40
- package/cpp/llama.cpp/src/llama-arch.cpp +210 -3
- package/cpp/llama.cpp/src/llama-arch.h +18 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +27 -1
- package/cpp/llama.cpp/src/llama-batch.h +8 -1
- package/cpp/llama.cpp/src/llama-chat.cpp +15 -0
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-graph.cpp +119 -184
- package/cpp/llama.cpp/src/llama-graph.h +47 -60
- package/cpp/llama.cpp/src/llama-hparams.cpp +7 -1
- package/cpp/llama.cpp/src/llama-hparams.h +3 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +28 -18
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +4 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +214 -65
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +62 -24
- package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +9 -4
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +20 -10
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +3 -0
- package/cpp/llama.cpp/src/llama-model.cpp +2530 -685
- package/cpp/llama.cpp/src/llama-model.h +18 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +1 -0
- package/cpp/llama.cpp/src/llama-vocab.cpp +13 -2
- package/cpp/llama.cpp/src/llama-vocab.h +41 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +4 -0
- package/ios/include/llama.h +0 -40
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5055 -4886
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3766
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4890
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5091 -4922
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4897
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3794
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -109,6 +109,7 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
|
|
|
109
109
|
}
|
|
110
110
|
|
|
111
111
|
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
|
|
112
|
+
#pragma METAL fp math_mode(safe)
|
|
112
113
|
float amax = 0.0f; // absolute max
|
|
113
114
|
float max = 0.0f;
|
|
114
115
|
|
|
@@ -138,6 +139,7 @@ void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
|
|
|
138
139
|
}
|
|
139
140
|
|
|
140
141
|
void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
|
|
142
|
+
#pragma METAL fp math_mode(safe)
|
|
141
143
|
float min = FLT_MAX;
|
|
142
144
|
float max = -FLT_MAX;
|
|
143
145
|
|
|
@@ -166,6 +168,7 @@ void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
|
|
|
166
168
|
}
|
|
167
169
|
|
|
168
170
|
void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
|
|
171
|
+
#pragma METAL fp math_mode(safe)
|
|
169
172
|
float amax = 0.0f; // absolute max
|
|
170
173
|
float max = 0.0f;
|
|
171
174
|
|
|
@@ -203,6 +206,7 @@ void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
|
|
|
203
206
|
}
|
|
204
207
|
|
|
205
208
|
void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
|
|
209
|
+
#pragma METAL fp math_mode(safe)
|
|
206
210
|
float max = src[0];
|
|
207
211
|
float min = src[0];
|
|
208
212
|
|
|
@@ -239,6 +243,7 @@ void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
|
|
|
239
243
|
}
|
|
240
244
|
|
|
241
245
|
void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
|
|
246
|
+
#pragma METAL fp math_mode(safe)
|
|
242
247
|
float amax = 0.0f; // absolute max
|
|
243
248
|
float max = 0.0f;
|
|
244
249
|
|
|
@@ -458,6 +463,7 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
|
|
|
458
463
|
}
|
|
459
464
|
|
|
460
465
|
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
|
|
466
|
+
#pragma METAL fp math_mode(safe)
|
|
461
467
|
float amax = 0.0f; // absolute max
|
|
462
468
|
|
|
463
469
|
for (int j = 0; j < QK8_0; j++) {
|
|
@@ -1008,16 +1014,18 @@ kernel void kernel_scale(
|
|
|
1008
1014
|
device const float * src0,
|
|
1009
1015
|
device float * dst,
|
|
1010
1016
|
constant float & scale,
|
|
1017
|
+
constant float & bias,
|
|
1011
1018
|
uint tpig[[thread_position_in_grid]]) {
|
|
1012
|
-
dst[tpig] = src0[tpig] * scale;
|
|
1019
|
+
dst[tpig] = src0[tpig] * scale + bias;
|
|
1013
1020
|
}
|
|
1014
1021
|
|
|
1015
1022
|
kernel void kernel_scale_4(
|
|
1016
1023
|
device const float4 * src0,
|
|
1017
1024
|
device float4 * dst,
|
|
1018
1025
|
constant float & scale,
|
|
1026
|
+
constant float & bias,
|
|
1019
1027
|
uint tpig[[thread_position_in_grid]]) {
|
|
1020
|
-
dst[tpig] = src0[tpig] * scale;
|
|
1028
|
+
dst[tpig] = src0[tpig] * scale + bias;
|
|
1021
1029
|
}
|
|
1022
1030
|
|
|
1023
1031
|
kernel void kernel_clamp(
|
|
@@ -1191,6 +1199,114 @@ kernel void kernel_neg(
|
|
|
1191
1199
|
dst[tpig] = -src0[tpig];
|
|
1192
1200
|
}
|
|
1193
1201
|
|
|
1202
|
+
kernel void kernel_reglu(
|
|
1203
|
+
device const char * src0,
|
|
1204
|
+
device const char * src1,
|
|
1205
|
+
device char * dst,
|
|
1206
|
+
constant ggml_metal_kargs_glu & args,
|
|
1207
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
|
1208
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
|
1209
|
+
uint ntg[[threads_per_threadgroup]]) {
|
|
1210
|
+
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1211
|
+
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1212
|
+
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
|
|
1213
|
+
|
|
1214
|
+
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1215
|
+
const float x0 = src0_row[i0];
|
|
1216
|
+
const float x1 = src1_row[i0];
|
|
1217
|
+
|
|
1218
|
+
dst_row[i0] = x0*x1*(x0 > 0.0f);
|
|
1219
|
+
}
|
|
1220
|
+
}
|
|
1221
|
+
|
|
1222
|
+
kernel void kernel_geglu(
|
|
1223
|
+
device const char * src0,
|
|
1224
|
+
device const char * src1,
|
|
1225
|
+
device char * dst,
|
|
1226
|
+
constant ggml_metal_kargs_glu & args,
|
|
1227
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
|
1228
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
|
1229
|
+
uint ntg[[threads_per_threadgroup]]) {
|
|
1230
|
+
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1231
|
+
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1232
|
+
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
|
|
1233
|
+
|
|
1234
|
+
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1235
|
+
const float x0 = src0_row[i0];
|
|
1236
|
+
const float x1 = src1_row[i0];
|
|
1237
|
+
|
|
1238
|
+
const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
|
|
1239
|
+
|
|
1240
|
+
dst_row[i0] = gelu*x1;
|
|
1241
|
+
}
|
|
1242
|
+
}
|
|
1243
|
+
|
|
1244
|
+
kernel void kernel_swiglu(
|
|
1245
|
+
device const char * src0,
|
|
1246
|
+
device const char * src1,
|
|
1247
|
+
device char * dst,
|
|
1248
|
+
constant ggml_metal_kargs_glu & args,
|
|
1249
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
|
1250
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
|
1251
|
+
uint ntg[[threads_per_threadgroup]]) {
|
|
1252
|
+
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1253
|
+
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1254
|
+
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
|
|
1255
|
+
|
|
1256
|
+
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1257
|
+
const float x0 = src0_row[i0];
|
|
1258
|
+
const float x1 = src1_row[i0];
|
|
1259
|
+
|
|
1260
|
+
const float silu = x0 / (1.0f + exp(-x0));
|
|
1261
|
+
|
|
1262
|
+
dst_row[i0] = silu*x1;
|
|
1263
|
+
}
|
|
1264
|
+
}
|
|
1265
|
+
|
|
1266
|
+
kernel void kernel_geglu_erf(
|
|
1267
|
+
device const char * src0,
|
|
1268
|
+
device const char * src1,
|
|
1269
|
+
device char * dst,
|
|
1270
|
+
constant ggml_metal_kargs_glu & args,
|
|
1271
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
|
1272
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
|
1273
|
+
uint ntg[[threads_per_threadgroup]]) {
|
|
1274
|
+
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1275
|
+
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1276
|
+
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
|
|
1277
|
+
|
|
1278
|
+
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1279
|
+
const float x0 = src0_row[i0];
|
|
1280
|
+
const float x1 = src1_row[i0];
|
|
1281
|
+
|
|
1282
|
+
const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
|
|
1283
|
+
|
|
1284
|
+
dst_row[i0] = gelu_erf*x1;
|
|
1285
|
+
}
|
|
1286
|
+
}
|
|
1287
|
+
|
|
1288
|
+
kernel void kernel_geglu_quick(
|
|
1289
|
+
device const char * src0,
|
|
1290
|
+
device const char * src1,
|
|
1291
|
+
device char * dst,
|
|
1292
|
+
constant ggml_metal_kargs_glu & args,
|
|
1293
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
|
1294
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
|
1295
|
+
uint ntg[[threads_per_threadgroup]]) {
|
|
1296
|
+
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1297
|
+
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1298
|
+
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
|
|
1299
|
+
|
|
1300
|
+
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1301
|
+
const float x0 = src0_row[i0];
|
|
1302
|
+
const float x1 = src1_row[i0];
|
|
1303
|
+
|
|
1304
|
+
const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
|
|
1305
|
+
|
|
1306
|
+
dst_row[i0] = gelu_quick*x1;
|
|
1307
|
+
}
|
|
1308
|
+
}
|
|
1309
|
+
|
|
1194
1310
|
template <bool norm>
|
|
1195
1311
|
kernel void kernel_sum_rows(
|
|
1196
1312
|
constant ggml_metal_kargs_sum_rows & args,
|
|
@@ -1253,24 +1369,28 @@ kernel void kernel_soft_max(
|
|
|
1253
1369
|
device char * dst,
|
|
1254
1370
|
constant ggml_metal_kargs_soft_max & args,
|
|
1255
1371
|
threadgroup float * buf [[threadgroup(0)]],
|
|
1256
|
-
|
|
1257
|
-
|
|
1372
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1373
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1258
1374
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
1259
1375
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
1260
|
-
|
|
1261
|
-
const
|
|
1262
|
-
const
|
|
1263
|
-
const
|
|
1376
|
+
uint3 tptg[[threads_per_threadgroup]]) {
|
|
1377
|
+
const int32_t i03 = tgpig.z;
|
|
1378
|
+
const int32_t i02 = tgpig.y;
|
|
1379
|
+
const int32_t i01 = tgpig.x;
|
|
1380
|
+
|
|
1381
|
+
const int32_t i13 = i03%args.ne13;
|
|
1382
|
+
const int32_t i12 = i02%args.ne12;
|
|
1383
|
+
const int32_t i11 = i01;
|
|
1264
1384
|
|
|
1265
|
-
device const float * psrc0 =
|
|
1266
|
-
device const T * pmask = src1 != src0 ? (device const
|
|
1267
|
-
device float * pdst =
|
|
1385
|
+
device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
|
1386
|
+
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
|
|
1387
|
+
device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
|
1268
1388
|
|
|
1269
1389
|
float slope = 1.0f;
|
|
1270
1390
|
|
|
1271
1391
|
// ALiBi
|
|
1272
1392
|
if (args.max_bias > 0.0f) {
|
|
1273
|
-
const
|
|
1393
|
+
const int32_t h = i02;
|
|
1274
1394
|
|
|
1275
1395
|
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
|
1276
1396
|
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
|
@@ -1281,13 +1401,13 @@ kernel void kernel_soft_max(
|
|
|
1281
1401
|
// parallel max
|
|
1282
1402
|
float lmax = -INFINITY;
|
|
1283
1403
|
|
|
1284
|
-
for (int i00 = tpitg; i00 < args.ne00; i00 +=
|
|
1404
|
+
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
|
|
1285
1405
|
lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
|
|
1286
1406
|
}
|
|
1287
1407
|
|
|
1288
1408
|
// find the max value in the block
|
|
1289
1409
|
float max_val = simd_max(lmax);
|
|
1290
|
-
if (
|
|
1410
|
+
if (tptg.x > N_SIMDWIDTH) {
|
|
1291
1411
|
if (sgitg == 0) {
|
|
1292
1412
|
buf[tiisg] = -INFINITY;
|
|
1293
1413
|
}
|
|
@@ -1306,7 +1426,7 @@ kernel void kernel_soft_max(
|
|
|
1306
1426
|
|
|
1307
1427
|
// parallel sum
|
|
1308
1428
|
float lsum = 0.0f;
|
|
1309
|
-
for (int i00 = tpitg; i00 < args.ne00; i00 +=
|
|
1429
|
+
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
|
|
1310
1430
|
const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
|
|
1311
1431
|
lsum += exp_psrc0;
|
|
1312
1432
|
pdst[i00] = exp_psrc0;
|
|
@@ -1318,7 +1438,7 @@ kernel void kernel_soft_max(
|
|
|
1318
1438
|
|
|
1319
1439
|
float sum = simd_sum(lsum);
|
|
1320
1440
|
|
|
1321
|
-
if (
|
|
1441
|
+
if (tptg.x > N_SIMDWIDTH) {
|
|
1322
1442
|
if (sgitg == 0) {
|
|
1323
1443
|
buf[tiisg] = 0.0f;
|
|
1324
1444
|
}
|
|
@@ -1337,7 +1457,7 @@ kernel void kernel_soft_max(
|
|
|
1337
1457
|
|
|
1338
1458
|
const float inv_sum = 1.0f/sum;
|
|
1339
1459
|
|
|
1340
|
-
for (int i00 = tpitg; i00 < args.ne00; i00 +=
|
|
1460
|
+
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
|
|
1341
1461
|
pdst[i00] *= inv_sum;
|
|
1342
1462
|
}
|
|
1343
1463
|
}
|
|
@@ -1349,23 +1469,27 @@ kernel void kernel_soft_max_4(
|
|
|
1349
1469
|
device char * dst,
|
|
1350
1470
|
constant ggml_metal_kargs_soft_max & args,
|
|
1351
1471
|
threadgroup float * buf [[threadgroup(0)]],
|
|
1352
|
-
|
|
1353
|
-
|
|
1472
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1473
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1354
1474
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
1355
1475
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
1356
|
-
|
|
1357
|
-
const
|
|
1358
|
-
const
|
|
1359
|
-
const
|
|
1476
|
+
uint3 tptg[[threads_per_threadgroup]]) {
|
|
1477
|
+
const int32_t i03 = tgpig.z;
|
|
1478
|
+
const int32_t i02 = tgpig.y;
|
|
1479
|
+
const int32_t i01 = tgpig.x;
|
|
1360
1480
|
|
|
1361
|
-
|
|
1362
|
-
|
|
1363
|
-
|
|
1481
|
+
const int32_t i13 = i03%args.ne13;
|
|
1482
|
+
const int32_t i12 = i02%args.ne12;
|
|
1483
|
+
const int32_t i11 = i01;
|
|
1484
|
+
|
|
1485
|
+
device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
|
1486
|
+
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
|
|
1487
|
+
device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
|
1364
1488
|
|
|
1365
1489
|
float slope = 1.0f;
|
|
1366
1490
|
|
|
1367
1491
|
if (args.max_bias > 0.0f) {
|
|
1368
|
-
const
|
|
1492
|
+
const int32_t h = i02;
|
|
1369
1493
|
|
|
1370
1494
|
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
|
1371
1495
|
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
|
@@ -1376,14 +1500,14 @@ kernel void kernel_soft_max_4(
|
|
|
1376
1500
|
// parallel max
|
|
1377
1501
|
float4 lmax4 = -INFINITY;
|
|
1378
1502
|
|
|
1379
|
-
for (int i00 = tpitg; i00 < args.ne00/4; i00 +=
|
|
1503
|
+
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
|
|
1380
1504
|
lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
|
|
1381
1505
|
}
|
|
1382
1506
|
|
|
1383
1507
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
|
1384
1508
|
|
|
1385
1509
|
float max_val = simd_max(lmax);
|
|
1386
|
-
if (
|
|
1510
|
+
if (tptg.x > N_SIMDWIDTH) {
|
|
1387
1511
|
if (sgitg == 0) {
|
|
1388
1512
|
buf[tiisg] = -INFINITY;
|
|
1389
1513
|
}
|
|
@@ -1402,7 +1526,7 @@ kernel void kernel_soft_max_4(
|
|
|
1402
1526
|
|
|
1403
1527
|
// parallel sum
|
|
1404
1528
|
float4 lsum4 = 0.0f;
|
|
1405
|
-
for (int i00 = tpitg; i00 < args.ne00/4; i00 +=
|
|
1529
|
+
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
|
|
1406
1530
|
const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
|
|
1407
1531
|
lsum4 += exp_psrc4;
|
|
1408
1532
|
pdst4[i00] = exp_psrc4;
|
|
@@ -1416,7 +1540,7 @@ kernel void kernel_soft_max_4(
|
|
|
1416
1540
|
|
|
1417
1541
|
float sum = simd_sum(lsum);
|
|
1418
1542
|
|
|
1419
|
-
if (
|
|
1543
|
+
if (tptg.x > N_SIMDWIDTH) {
|
|
1420
1544
|
if (sgitg == 0) {
|
|
1421
1545
|
buf[tiisg] = 0.0f;
|
|
1422
1546
|
}
|
|
@@ -1435,7 +1559,7 @@ kernel void kernel_soft_max_4(
|
|
|
1435
1559
|
|
|
1436
1560
|
const float inv_sum = 1.0f/sum;
|
|
1437
1561
|
|
|
1438
|
-
for (int i00 = tpitg; i00 < args.ne00/4; i00 +=
|
|
1562
|
+
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
|
|
1439
1563
|
pdst4[i00] *= inv_sum;
|
|
1440
1564
|
}
|
|
1441
1565
|
}
|
|
@@ -1521,7 +1645,7 @@ kernel void kernel_ssm_conv_f32(
|
|
|
1521
1645
|
x[0] = sumf;
|
|
1522
1646
|
}
|
|
1523
1647
|
|
|
1524
|
-
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
|
|
1648
|
+
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
|
|
1525
1649
|
kernel void kernel_ssm_scan_f32(
|
|
1526
1650
|
device const void * src0,
|
|
1527
1651
|
device const void * src1,
|
|
@@ -1529,46 +1653,119 @@ kernel void kernel_ssm_scan_f32(
|
|
|
1529
1653
|
device const void * src3,
|
|
1530
1654
|
device const void * src4,
|
|
1531
1655
|
device const void * src5,
|
|
1656
|
+
device const void * src6,
|
|
1532
1657
|
device float * dst,
|
|
1533
1658
|
constant ggml_metal_kargs_ssm_scan & args,
|
|
1534
1659
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1535
1660
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1536
1661
|
uint3 ntg[[threads_per_threadgroup]]) {
|
|
1537
|
-
const int64_t
|
|
1538
|
-
const int64_t
|
|
1662
|
+
const int64_t i1 = 0;
|
|
1663
|
+
const int64_t ir = tgpig.x; // current head
|
|
1664
|
+
const int64_t i3 = tgpig.y; // current seq
|
|
1665
|
+
|
|
1666
|
+
const uint64_t nb00 = sizeof(float);
|
|
1667
|
+
const uint64_t nb10 = sizeof(float);
|
|
1668
|
+
const uint64_t nb20 = sizeof(float);
|
|
1539
1669
|
|
|
1540
1670
|
const int64_t nc = args.d_state;
|
|
1541
|
-
|
|
1671
|
+
const int64_t nr = args.d_inner;
|
|
1672
|
+
const int64_t nh = args.n_head;
|
|
1673
|
+
const int64_t ng = args.n_group;
|
|
1542
1674
|
const int64_t n_t = args.n_seq_tokens;
|
|
1543
|
-
|
|
1675
|
+
|
|
1676
|
+
const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
|
|
1677
|
+
|
|
1678
|
+
device const int32_t * ids = (device const int32_t *) src6;
|
|
1679
|
+
|
|
1680
|
+
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
|
|
1681
|
+
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
|
1544
1682
|
|
|
1545
1683
|
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
|
1546
|
-
device const float *
|
|
1547
|
-
device const float *
|
|
1548
|
-
device const float *
|
|
1549
|
-
device const float *
|
|
1550
|
-
device const float *
|
|
1551
|
-
device
|
|
1552
|
-
|
|
1553
|
-
|
|
1554
|
-
|
|
1555
|
-
if (i2 > 0) {
|
|
1556
|
-
s0 = s;
|
|
1557
|
-
}
|
|
1558
|
-
|
|
1559
|
-
// i1 == 0
|
|
1560
|
-
float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
|
1561
|
-
float x_dt = x[0] * dt_soft_plus;
|
|
1684
|
+
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
|
|
1685
|
+
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
|
|
1686
|
+
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
|
|
1687
|
+
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
|
|
1688
|
+
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
|
|
1689
|
+
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
|
|
1690
|
+
|
|
1691
|
+
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
|
1692
|
+
const float x_dt = x[0] * dt_soft_plus;
|
|
1562
1693
|
float sumf = 0.0f;
|
|
1563
1694
|
|
|
1564
1695
|
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
|
1565
|
-
int64_t i = i0;
|
|
1566
|
-
float state = (s0[i] * exp(dt_soft_plus * A[
|
|
1696
|
+
const int64_t i = i0 + i1*nc;
|
|
1697
|
+
const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
|
|
1567
1698
|
sumf += state * C[i0];
|
|
1568
1699
|
s[i] = state;
|
|
1569
1700
|
}
|
|
1570
1701
|
|
|
1571
1702
|
y[0] = sumf;
|
|
1703
|
+
|
|
1704
|
+
// recurse
|
|
1705
|
+
s0 = s;
|
|
1706
|
+
}
|
|
1707
|
+
}
|
|
1708
|
+
|
|
1709
|
+
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
|
|
1710
|
+
// TODO: optimize (e.g. by parallelizing over d_state)
|
|
1711
|
+
kernel void kernel_ssm_scan_f32_group(
|
|
1712
|
+
device const void * src0,
|
|
1713
|
+
device const void * src1,
|
|
1714
|
+
device const void * src2,
|
|
1715
|
+
device const void * src3,
|
|
1716
|
+
device const void * src4,
|
|
1717
|
+
device const void * src5,
|
|
1718
|
+
device const void * src6,
|
|
1719
|
+
device float * dst,
|
|
1720
|
+
constant ggml_metal_kargs_ssm_scan & args,
|
|
1721
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1722
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1723
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
1724
|
+
const int64_t i1 = tgpig.x;
|
|
1725
|
+
const int64_t ir = tgpig.y; // current head
|
|
1726
|
+
const int64_t i3 = tgpig.z; // current seq
|
|
1727
|
+
|
|
1728
|
+
const uint64_t nb00 = sizeof(float);
|
|
1729
|
+
const uint64_t nb10 = sizeof(float);
|
|
1730
|
+
const uint64_t nb20 = sizeof(float);
|
|
1731
|
+
|
|
1732
|
+
const int64_t nc = args.d_state;
|
|
1733
|
+
const int64_t nr = args.d_inner;
|
|
1734
|
+
const int64_t nh = args.n_head;
|
|
1735
|
+
const int64_t ng = args.n_group;
|
|
1736
|
+
const int64_t n_t = args.n_seq_tokens;
|
|
1737
|
+
|
|
1738
|
+
const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
|
|
1739
|
+
|
|
1740
|
+
device const int32_t * ids = (device const int32_t *) src6;
|
|
1741
|
+
|
|
1742
|
+
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
|
|
1743
|
+
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
|
1744
|
+
|
|
1745
|
+
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
|
1746
|
+
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
|
|
1747
|
+
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
|
|
1748
|
+
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
|
|
1749
|
+
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
|
|
1750
|
+
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
|
|
1751
|
+
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
|
|
1752
|
+
|
|
1753
|
+
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
|
1754
|
+
const float x_dt = x[0] * dt_soft_plus;
|
|
1755
|
+
const float dA = exp(dt_soft_plus * A[0]);
|
|
1756
|
+
float sumf = 0.0f;
|
|
1757
|
+
|
|
1758
|
+
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
|
1759
|
+
const int64_t i = i0 + i1*nc;
|
|
1760
|
+
const float state = (s0[i] * dA) + (B[i0] * x_dt);
|
|
1761
|
+
sumf += state * C[i0];
|
|
1762
|
+
s[i] = state;
|
|
1763
|
+
}
|
|
1764
|
+
|
|
1765
|
+
y[0] = sumf;
|
|
1766
|
+
|
|
1767
|
+
// recurse
|
|
1768
|
+
s0 = s;
|
|
1572
1769
|
}
|
|
1573
1770
|
}
|
|
1574
1771
|
|
|
@@ -3709,7 +3906,7 @@ kernel void kernel_flash_attn_ext(
|
|
|
3709
3906
|
// load the mask in shared memory
|
|
3710
3907
|
#pragma unroll(Q)
|
|
3711
3908
|
for (short j = 0; j < Q; ++j) {
|
|
3712
|
-
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
|
|
3909
|
+
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
|
|
3713
3910
|
|
|
3714
3911
|
const float m = pm[ic + tiisg];
|
|
3715
3912
|
|
|
@@ -4195,7 +4392,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
4195
4392
|
const bool has_mask = mask != q;
|
|
4196
4393
|
|
|
4197
4394
|
// pointer to the mask
|
|
4198
|
-
device const half * pm = (device const half *) (mask + iq1*args.nb31);
|
|
4395
|
+
device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
|
|
4199
4396
|
|
|
4200
4397
|
float slope = 1.0f;
|
|
4201
4398
|
|
|
@@ -65,6 +65,7 @@ set(GGML_OPENCL_KERNELS
|
|
|
65
65
|
gemv_noshuffle_general
|
|
66
66
|
gemv_noshuffle
|
|
67
67
|
get_rows
|
|
68
|
+
glu
|
|
68
69
|
group_norm
|
|
69
70
|
im2col_f32
|
|
70
71
|
im2col_f16
|
|
@@ -87,6 +88,7 @@ set(GGML_OPENCL_KERNELS
|
|
|
87
88
|
rms_norm
|
|
88
89
|
rope
|
|
89
90
|
scale
|
|
91
|
+
set_rows
|
|
90
92
|
sigmoid
|
|
91
93
|
silu
|
|
92
94
|
softmax_4_f32
|
|
@@ -102,6 +104,7 @@ set(GGML_OPENCL_KERNELS
|
|
|
102
104
|
tanh
|
|
103
105
|
pad
|
|
104
106
|
repeat
|
|
107
|
+
mul_mat_f16_f32
|
|
105
108
|
)
|
|
106
109
|
|
|
107
110
|
foreach (K ${GGML_OPENCL_KERNELS})
|