@novastera-oss/llamarn 0.2.9 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +0 -1
- package/cpp/llama.cpp/README.md +4 -5
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +17 -0
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.h +4 -0
- package/cpp/llama.cpp/convert_hf_to_gguf.py +745 -6
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
- package/cpp/llama.cpp/ggml/CMakeLists.txt +7 -2
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1203 -163
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +185 -79
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +64 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +35 -9
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +167 -39
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +254 -57
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +505 -40
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +60 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +711 -292
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
- package/cpp/llama.cpp/ggml/src/ggml.c +382 -61
- package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +209 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +73 -21
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
- package/cpp/llama.cpp/include/llama.h +0 -40
- package/cpp/llama.cpp/src/llama-arch.cpp +210 -3
- package/cpp/llama.cpp/src/llama-arch.h +18 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +27 -1
- package/cpp/llama.cpp/src/llama-batch.h +8 -1
- package/cpp/llama.cpp/src/llama-chat.cpp +15 -0
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-graph.cpp +119 -184
- package/cpp/llama.cpp/src/llama-graph.h +47 -60
- package/cpp/llama.cpp/src/llama-hparams.cpp +7 -1
- package/cpp/llama.cpp/src/llama-hparams.h +3 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +28 -18
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +4 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +214 -65
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +62 -24
- package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +9 -4
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +20 -10
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +3 -0
- package/cpp/llama.cpp/src/llama-model.cpp +2530 -685
- package/cpp/llama.cpp/src/llama-model.h +18 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +1 -0
- package/cpp/llama.cpp/src/llama-vocab.cpp +13 -2
- package/cpp/llama.cpp/src/llama-vocab.h +41 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +4 -0
- package/ios/include/llama.h +0 -40
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5055 -4886
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3766
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4890
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5091 -4922
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4897
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3794
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -217,6 +217,7 @@ enum ggml_metal_kernel_type {
|
|
|
217
217
|
GGML_METAL_KERNEL_TYPE_NORM,
|
|
218
218
|
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
|
219
219
|
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
|
|
220
|
+
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
|
|
220
221
|
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
|
|
221
222
|
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
|
|
222
223
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
|
@@ -526,6 +527,11 @@ enum ggml_metal_kernel_type {
|
|
|
526
527
|
GGML_METAL_KERNEL_TYPE_SIN,
|
|
527
528
|
GGML_METAL_KERNEL_TYPE_COS,
|
|
528
529
|
GGML_METAL_KERNEL_TYPE_NEG,
|
|
530
|
+
GGML_METAL_KERNEL_TYPE_REGLU,
|
|
531
|
+
GGML_METAL_KERNEL_TYPE_GEGLU,
|
|
532
|
+
GGML_METAL_KERNEL_TYPE_SWIGLU,
|
|
533
|
+
GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
|
|
534
|
+
GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
|
|
529
535
|
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
|
530
536
|
GGML_METAL_KERNEL_TYPE_MEAN,
|
|
531
537
|
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
|
@@ -1193,6 +1199,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
1193
1199
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
|
1194
1200
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
|
1195
1201
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
|
|
1202
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
|
|
1196
1203
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
|
|
1197
1204
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
|
|
1198
1205
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
|
|
@@ -1502,6 +1509,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
1502
1509
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
|
1503
1510
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
|
1504
1511
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
|
1512
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
|
|
1513
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
|
|
1514
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
|
|
1515
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
|
|
1516
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
|
|
1505
1517
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
|
1506
1518
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
|
|
1507
1519
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
|
@@ -1680,6 +1692,17 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
1680
1692
|
default:
|
|
1681
1693
|
return false;
|
|
1682
1694
|
}
|
|
1695
|
+
case GGML_OP_GLU:
|
|
1696
|
+
switch (ggml_get_glu_op(op)) {
|
|
1697
|
+
case GGML_GLU_OP_REGLU:
|
|
1698
|
+
case GGML_GLU_OP_GEGLU:
|
|
1699
|
+
case GGML_GLU_OP_SWIGLU:
|
|
1700
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
1701
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
1702
|
+
return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
|
1703
|
+
default:
|
|
1704
|
+
return false;
|
|
1705
|
+
}
|
|
1683
1706
|
case GGML_OP_NONE:
|
|
1684
1707
|
case GGML_OP_RESHAPE:
|
|
1685
1708
|
case GGML_OP_VIEW:
|
|
@@ -1710,7 +1733,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
1710
1733
|
case GGML_OP_MEAN:
|
|
1711
1734
|
case GGML_OP_SOFT_MAX:
|
|
1712
1735
|
case GGML_OP_GROUP_NORM:
|
|
1713
|
-
return has_simdgroup_reduction &&
|
|
1736
|
+
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
|
1714
1737
|
case GGML_OP_RMS_NORM:
|
|
1715
1738
|
case GGML_OP_L2_NORM:
|
|
1716
1739
|
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
|
@@ -2233,7 +2256,9 @@ static bool ggml_metal_encode_node(
|
|
|
2233
2256
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
2234
2257
|
|
|
2235
2258
|
float scale;
|
|
2236
|
-
|
|
2259
|
+
float bias;
|
|
2260
|
+
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(float));
|
|
2261
|
+
memcpy(&bias, ((const int32_t *) dst->op_params) + 1, sizeof(float));
|
|
2237
2262
|
|
|
2238
2263
|
int64_t n = ggml_nelements(dst);
|
|
2239
2264
|
|
|
@@ -2250,6 +2275,7 @@ static bool ggml_metal_encode_node(
|
|
|
2250
2275
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2251
2276
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2252
2277
|
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
|
2278
|
+
[encoder setBytes:&bias length:sizeof(bias) atIndex:3];
|
|
2253
2279
|
|
|
2254
2280
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2255
2281
|
} break;
|
|
@@ -2419,6 +2445,68 @@ static bool ggml_metal_encode_node(
|
|
|
2419
2445
|
GGML_ABORT("fatal error");
|
|
2420
2446
|
}
|
|
2421
2447
|
} break;
|
|
2448
|
+
case GGML_OP_GLU:
|
|
2449
|
+
{
|
|
2450
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
2451
|
+
|
|
2452
|
+
if (src1) {
|
|
2453
|
+
GGML_ASSERT(ggml_are_same_shape(src0, src1));
|
|
2454
|
+
}
|
|
2455
|
+
|
|
2456
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
2457
|
+
|
|
2458
|
+
switch (ggml_get_glu_op(node)) {
|
|
2459
|
+
case GGML_GLU_OP_REGLU:
|
|
2460
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline;
|
|
2461
|
+
break;
|
|
2462
|
+
case GGML_GLU_OP_GEGLU:
|
|
2463
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline;
|
|
2464
|
+
break;
|
|
2465
|
+
case GGML_GLU_OP_SWIGLU:
|
|
2466
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
|
|
2467
|
+
break;
|
|
2468
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
2469
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
|
|
2470
|
+
break;
|
|
2471
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
2472
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_QUICK].pipeline;
|
|
2473
|
+
break;
|
|
2474
|
+
default:
|
|
2475
|
+
GGML_ABORT("fatal error");
|
|
2476
|
+
}
|
|
2477
|
+
|
|
2478
|
+
const int32_t swp = ((const int32_t *) dst->op_params)[1];
|
|
2479
|
+
|
|
2480
|
+
const int32_t i00 = swp ? ne0 : 0;
|
|
2481
|
+
const int32_t i10 = swp ? 0 : ne0;
|
|
2482
|
+
|
|
2483
|
+
ggml_metal_kargs_glu args = {
|
|
2484
|
+
/*.ne00 =*/ ne00,
|
|
2485
|
+
/*.nb01 =*/ nb01,
|
|
2486
|
+
/*.ne10 =*/ src1 ? ne10 : ne00,
|
|
2487
|
+
/*.nb11 =*/ src1 ? nb11 : nb01,
|
|
2488
|
+
/*.ne0 =*/ ne0,
|
|
2489
|
+
/*.nb1 =*/ nb1,
|
|
2490
|
+
/*.i00 =*/ src1 ? 0 : i00,
|
|
2491
|
+
/*.i10 =*/ src1 ? 0 : i10,
|
|
2492
|
+
};
|
|
2493
|
+
|
|
2494
|
+
[encoder setComputePipelineState:pipeline];
|
|
2495
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2496
|
+
if (src1) {
|
|
2497
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
2498
|
+
} else {
|
|
2499
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
2500
|
+
}
|
|
2501
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
2502
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
|
2503
|
+
|
|
2504
|
+
const int64_t nrows = ggml_nrows(src0);
|
|
2505
|
+
|
|
2506
|
+
const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2);
|
|
2507
|
+
|
|
2508
|
+
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2509
|
+
} break;
|
|
2422
2510
|
case GGML_OP_SQR:
|
|
2423
2511
|
{
|
|
2424
2512
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
@@ -2573,10 +2661,7 @@ static bool ggml_metal_encode_node(
|
|
|
2573
2661
|
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
|
|
2574
2662
|
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
|
2575
2663
|
|
|
2576
|
-
const
|
|
2577
|
-
const int64_t nrows_y = src0->ne[1];
|
|
2578
|
-
|
|
2579
|
-
const uint32_t n_head = nrows_x/nrows_y;
|
|
2664
|
+
const uint32_t n_head = src0->ne[2];
|
|
2580
2665
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
|
2581
2666
|
|
|
2582
2667
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
@@ -2636,6 +2721,18 @@ static bool ggml_metal_encode_node(
|
|
|
2636
2721
|
/*.ne00 =*/ ne00,
|
|
2637
2722
|
/*.ne01 =*/ ne01,
|
|
2638
2723
|
/*.ne02 =*/ ne02,
|
|
2724
|
+
/*.nb01 =*/ nb01,
|
|
2725
|
+
/*.nb02 =*/ nb02,
|
|
2726
|
+
/*.nb03 =*/ nb03,
|
|
2727
|
+
/*.ne11 =*/ ne11,
|
|
2728
|
+
/*.ne12 =*/ ne12,
|
|
2729
|
+
/*.ne13 =*/ ne13,
|
|
2730
|
+
/*.nb11 =*/ nb11,
|
|
2731
|
+
/*.nb12 =*/ nb12,
|
|
2732
|
+
/*.nb13 =*/ nb13,
|
|
2733
|
+
/*.nb1 =*/ nb1,
|
|
2734
|
+
/*.nb2 =*/ nb2,
|
|
2735
|
+
/*.nb3 =*/ nb3,
|
|
2639
2736
|
/*.scale =*/ scale,
|
|
2640
2737
|
/*.max_bias =*/ max_bias,
|
|
2641
2738
|
/*.m0 =*/ m0,
|
|
@@ -2655,7 +2752,7 @@ static bool ggml_metal_encode_node(
|
|
|
2655
2752
|
|
|
2656
2753
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
2657
2754
|
|
|
2658
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01
|
|
2755
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2659
2756
|
} break;
|
|
2660
2757
|
case GGML_OP_DIAG_MASK_INF:
|
|
2661
2758
|
{
|
|
@@ -2729,71 +2826,91 @@ static bool ggml_metal_encode_node(
|
|
|
2729
2826
|
struct ggml_tensor * src3 = node->src[3];
|
|
2730
2827
|
struct ggml_tensor * src4 = node->src[4];
|
|
2731
2828
|
struct ggml_tensor * src5 = node->src[5];
|
|
2829
|
+
struct ggml_tensor * src6 = node->src[6];
|
|
2732
2830
|
|
|
2733
2831
|
GGML_ASSERT(src3);
|
|
2734
2832
|
GGML_ASSERT(src4);
|
|
2735
2833
|
GGML_ASSERT(src5);
|
|
2834
|
+
GGML_ASSERT(src6);
|
|
2736
2835
|
|
|
2737
2836
|
size_t offs_src3 = 0;
|
|
2738
2837
|
size_t offs_src4 = 0;
|
|
2739
2838
|
size_t offs_src5 = 0;
|
|
2839
|
+
size_t offs_src6 = 0;
|
|
2740
2840
|
|
|
2741
2841
|
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
|
|
2742
2842
|
id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
|
|
2743
2843
|
id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
|
|
2844
|
+
id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;
|
|
2744
2845
|
|
|
2745
|
-
const int64_t ne30 = src3->ne[0];
|
|
2846
|
+
const int64_t ne30 = src3->ne[0];
|
|
2746
2847
|
const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
|
|
2747
2848
|
|
|
2748
|
-
const uint64_t nb30 = src3->nb[0];
|
|
2849
|
+
const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30);
|
|
2749
2850
|
const uint64_t nb31 = src3->nb[1];
|
|
2750
2851
|
|
|
2751
2852
|
const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
|
|
2752
|
-
const int64_t ne41 = src4->ne[1];
|
|
2853
|
+
const int64_t ne41 = src4->ne[1];
|
|
2753
2854
|
const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
|
|
2855
|
+
const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43);
|
|
2754
2856
|
|
|
2755
|
-
const uint64_t nb40 = src4->nb[0];
|
|
2857
|
+
const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40);
|
|
2756
2858
|
const uint64_t nb41 = src4->nb[1];
|
|
2757
2859
|
const uint64_t nb42 = src4->nb[2];
|
|
2860
|
+
const uint64_t nb43 = src4->nb[3];
|
|
2758
2861
|
|
|
2759
2862
|
const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
|
|
2760
2863
|
const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
|
|
2761
2864
|
const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
|
|
2865
|
+
const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53);
|
|
2762
2866
|
|
|
2763
|
-
const uint64_t nb50 = src5->nb[0];
|
|
2867
|
+
const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50);
|
|
2764
2868
|
const uint64_t nb51 = src5->nb[1];
|
|
2765
2869
|
const uint64_t nb52 = src5->nb[2];
|
|
2870
|
+
const uint64_t nb53 = src5->nb[3];
|
|
2871
|
+
|
|
2872
|
+
const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60);
|
|
2873
|
+
|
|
2874
|
+
const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);
|
|
2766
2875
|
|
|
2767
2876
|
const int64_t d_state = ne00;
|
|
2768
2877
|
const int64_t d_inner = ne01;
|
|
2769
|
-
const int64_t
|
|
2770
|
-
const int64_t
|
|
2878
|
+
const int64_t n_head = ne02;
|
|
2879
|
+
const int64_t n_group = ne41;
|
|
2880
|
+
const int64_t n_seq_tokens = ne12;
|
|
2881
|
+
const int64_t n_seqs = ne13;
|
|
2771
2882
|
|
|
2772
|
-
id<MTLComputePipelineState> pipeline =
|
|
2883
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
2884
|
+
|
|
2885
|
+
if (ne30 == 1) {
|
|
2886
|
+
// Mamba-2
|
|
2887
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
|
|
2888
|
+
} else {
|
|
2889
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
|
|
2890
|
+
}
|
|
2773
2891
|
|
|
2774
2892
|
ggml_metal_kargs_ssm_scan args = {
|
|
2775
|
-
/*.d_state
|
|
2776
|
-
/*.d_inner
|
|
2893
|
+
/*.d_state =*/ d_state,
|
|
2894
|
+
/*.d_inner =*/ d_inner,
|
|
2895
|
+
/*.n_head =*/ n_head,
|
|
2896
|
+
/*.n_group =*/ n_group,
|
|
2777
2897
|
/*.n_seq_tokens =*/ n_seq_tokens,
|
|
2778
|
-
/*.n_seqs
|
|
2779
|
-
/*.
|
|
2780
|
-
/*.
|
|
2781
|
-
/*.
|
|
2782
|
-
/*.
|
|
2783
|
-
/*.
|
|
2784
|
-
/*.
|
|
2785
|
-
/*.
|
|
2786
|
-
/*.
|
|
2787
|
-
/*.
|
|
2788
|
-
/*.
|
|
2789
|
-
/*.
|
|
2790
|
-
/*.
|
|
2791
|
-
/*.
|
|
2792
|
-
/*.
|
|
2793
|
-
/*.
|
|
2794
|
-
/*.nb50 =*/ nb50,
|
|
2795
|
-
/*.nb51 =*/ nb51,
|
|
2796
|
-
/*.nb52 =*/ nb52,
|
|
2898
|
+
/*.n_seqs =*/ n_seqs,
|
|
2899
|
+
/*.nb01 =*/ nb01,
|
|
2900
|
+
/*.nb02 =*/ nb02,
|
|
2901
|
+
/*.nb03 =*/ nb03,
|
|
2902
|
+
/*.nb11 =*/ nb11,
|
|
2903
|
+
/*.nb12 =*/ nb12,
|
|
2904
|
+
/*.nb13 =*/ nb13,
|
|
2905
|
+
/*.nb21 =*/ nb21,
|
|
2906
|
+
/*.nb22 =*/ nb22,
|
|
2907
|
+
/*.nb31 =*/ nb31,
|
|
2908
|
+
/*.nb41 =*/ nb41,
|
|
2909
|
+
/*.nb42 =*/ nb42,
|
|
2910
|
+
/*.nb43 =*/ nb43,
|
|
2911
|
+
/*.nb51 =*/ nb51,
|
|
2912
|
+
/*.nb52 =*/ nb52,
|
|
2913
|
+
/*.nb53 =*/ nb53,
|
|
2797
2914
|
};
|
|
2798
2915
|
|
|
2799
2916
|
[encoder setComputePipelineState:pipeline];
|
|
@@ -2803,10 +2920,17 @@ static bool ggml_metal_encode_node(
|
|
|
2803
2920
|
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
|
2804
2921
|
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
|
2805
2922
|
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
|
2806
|
-
[encoder setBuffer:
|
|
2807
|
-
[encoder
|
|
2923
|
+
[encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
|
|
2924
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:7];
|
|
2925
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:8];
|
|
2808
2926
|
|
|
2809
|
-
|
|
2927
|
+
if (ne30 == 1) {
|
|
2928
|
+
// Mamba-2
|
|
2929
|
+
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2930
|
+
} else {
|
|
2931
|
+
GGML_ASSERT(d_inner == 1);
|
|
2932
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2933
|
+
}
|
|
2810
2934
|
} break;
|
|
2811
2935
|
case GGML_OP_RWKV_WKV6:
|
|
2812
2936
|
{
|
|
@@ -4908,7 +5032,11 @@ static bool ggml_metal_encode_node(
|
|
|
4908
5032
|
/*.nb21 =*/ nb21,
|
|
4909
5033
|
/*.nb22 =*/ nb22,
|
|
4910
5034
|
/*.nb23 =*/ nb23,
|
|
5035
|
+
/*.ne32 =*/ ne32,
|
|
5036
|
+
/*.ne33 =*/ ne33,
|
|
4911
5037
|
/*.nb31 =*/ nb31,
|
|
5038
|
+
/*.nb32 =*/ nb32,
|
|
5039
|
+
/*.nb33 =*/ nb33,
|
|
4912
5040
|
/*.ne1 =*/ ne1,
|
|
4913
5041
|
/*.ne2 =*/ ne2,
|
|
4914
5042
|
/*.scale =*/ scale,
|