@novastera-oss/llamarn 0.2.1 → 0.2.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +80 -14
- package/RNLlamaCpp.podspec +10 -3
- package/android/CMakeLists.txt +8 -0
- package/android/src/main/cpp/include/llama.h +62 -125
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/README.md +11 -3
- package/cpp/llama.cpp/build-xcframework.sh +1 -0
- package/cpp/llama.cpp/common/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/common/arg.cpp +153 -113
- package/cpp/llama.cpp/common/chat-parser.cpp +379 -0
- package/cpp/llama.cpp/common/chat-parser.h +117 -0
- package/cpp/llama.cpp/common/chat.cpp +847 -699
- package/cpp/llama.cpp/common/chat.h +73 -6
- package/cpp/llama.cpp/common/common.cpp +50 -82
- package/cpp/llama.cpp/common/common.h +21 -17
- package/cpp/llama.cpp/common/json-partial.cpp +255 -0
- package/cpp/llama.cpp/common/json-partial.h +37 -0
- package/cpp/llama.cpp/common/minja/chat-template.hpp +9 -5
- package/cpp/llama.cpp/common/minja/minja.hpp +69 -36
- package/cpp/llama.cpp/common/regex-partial.cpp +204 -0
- package/cpp/llama.cpp/common/regex-partial.h +56 -0
- package/cpp/llama.cpp/common/sampling.cpp +7 -8
- package/cpp/llama.cpp/convert_hf_to_gguf.py +453 -118
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +120 -68
- package/cpp/llama.cpp/ggml/CMakeLists.txt +2 -1
- package/cpp/llama.cpp/ggml/cmake/common.cmake +25 -0
- package/cpp/llama.cpp/ggml/include/ggml-opt.h +49 -28
- package/cpp/llama.cpp/ggml/include/ggml.h +26 -7
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +16 -10
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +604 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +54 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +50 -51
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +779 -19
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +322 -100
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +117 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +220 -49
- package/cpp/llama.cpp/ggml/src/ggml-cuda/acc.cu +40 -26
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +11 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +15 -7
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +266 -64
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +49 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +48 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +2 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +5 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/quantize.cu +7 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +10 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +99 -17
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +200 -2
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +972 -178
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +373 -190
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +101 -5
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +31 -33
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +29 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +4 -5
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +9 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +84 -72
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +324 -129
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +31 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +95 -68
- package/cpp/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +69 -43
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -91
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -181
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +6 -152
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +2 -118
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +12 -1
- package/cpp/llama.cpp/ggml/src/ggml.c +107 -36
- package/cpp/llama.cpp/ggml/src/gguf.cpp +33 -33
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +100 -15
- package/cpp/llama.cpp/gguf-py/gguf/gguf_reader.py +1 -1
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +44 -12
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_editor_gui.py +21 -10
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +5 -2
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +128 -31
- package/cpp/llama.cpp/gguf-py/gguf/utility.py +1 -1
- package/cpp/llama.cpp/gguf-py/pyproject.toml +1 -1
- package/cpp/llama.cpp/include/llama.h +62 -125
- package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-nomic-bert-moe.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.out +1 -1
- package/cpp/llama.cpp/models/templates/Qwen-QwQ-32B.jinja +62 -0
- package/cpp/llama.cpp/models/templates/Qwen-Qwen3-0.6B.jinja +85 -0
- package/cpp/llama.cpp/models/templates/README.md +2 -0
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +5 -1
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +5 -1
- package/cpp/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +2 -0
- package/cpp/llama.cpp/requirements/requirements-gguf_editor_gui.txt +1 -1
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +6 -0
- package/cpp/llama.cpp/src/llama-arch.h +2 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +3 -1
- package/cpp/llama.cpp/src/llama-context.cpp +340 -123
- package/cpp/llama.cpp/src/llama-context.h +30 -0
- package/cpp/llama.cpp/src/llama-cparams.cpp +4 -0
- package/cpp/llama.cpp/src/llama-cparams.h +2 -0
- package/cpp/llama.cpp/src/llama-grammar.cpp +12 -2
- package/cpp/llama.cpp/src/llama-graph.cpp +157 -247
- package/cpp/llama.cpp/src/llama-graph.h +52 -7
- package/cpp/llama.cpp/src/llama-hparams.cpp +17 -1
- package/cpp/llama.cpp/src/llama-hparams.h +37 -5
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +742 -481
- package/cpp/llama.cpp/src/llama-kv-cache.h +196 -99
- package/cpp/llama.cpp/src/llama-kv-cells.h +379 -0
- package/cpp/llama.cpp/src/llama-memory.h +4 -3
- package/cpp/llama.cpp/src/llama-model-loader.cpp +22 -17
- package/cpp/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/cpp/llama.cpp/src/llama-model-saver.h +37 -0
- package/cpp/llama.cpp/src/llama-model.cpp +529 -172
- package/cpp/llama.cpp/src/llama-model.h +6 -1
- package/cpp/llama.cpp/src/llama-quant.cpp +15 -13
- package/cpp/llama.cpp/src/llama-sampling.cpp +2 -2
- package/cpp/llama.cpp/src/llama-vocab.cpp +35 -8
- package/cpp/llama.cpp/src/llama-vocab.h +6 -0
- package/cpp/llama.cpp/src/llama.cpp +14 -0
- package/cpp/rn-completion.cpp +4 -2
- package/ios/include/chat.h +73 -6
- package/ios/include/common/minja/chat-template.hpp +9 -5
- package/ios/include/common/minja/minja.hpp +69 -36
- package/ios/include/common.h +21 -17
- package/ios/include/llama.h +62 -125
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4617 -4487
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3557 -3435
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3559 -3437
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4616 -4487
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4637 -4508
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3556 -3435
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4653 -4523
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4674 -4544
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3587 -3465
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/common/stb_image.h +0 -7988
- package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +0 -46
|
@@ -856,6 +856,7 @@ kernel void kernel_tanh(
|
|
|
856
856
|
constant float GELU_COEF_A = 0.044715f;
|
|
857
857
|
constant float GELU_QUICK_COEF = -1.702f;
|
|
858
858
|
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
|
859
|
+
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
|
|
859
860
|
|
|
860
861
|
kernel void kernel_gelu(
|
|
861
862
|
device const float * src0,
|
|
@@ -897,6 +898,42 @@ kernel void kernel_gelu_quick_4(
|
|
|
897
898
|
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
|
|
898
899
|
}
|
|
899
900
|
|
|
901
|
+
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
|
|
902
|
+
// ref: https://www.johndcook.com/blog/python_erf/
|
|
903
|
+
constant float p_erf = 0.3275911f;
|
|
904
|
+
constant float a1_erf = 0.254829592f;
|
|
905
|
+
constant float a2_erf = -0.284496736f;
|
|
906
|
+
constant float a3_erf = 1.421413741f;
|
|
907
|
+
constant float a4_erf = -1.453152027f;
|
|
908
|
+
constant float a5_erf = 1.061405429f;
|
|
909
|
+
|
|
910
|
+
template<typename T>
|
|
911
|
+
T erf_approx(T x) {
|
|
912
|
+
T sign_x = sign(x);
|
|
913
|
+
x = fabs(x);
|
|
914
|
+
T t = 1.0f / (1.0f + p_erf * x);
|
|
915
|
+
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
|
|
916
|
+
return sign_x * y;
|
|
917
|
+
}
|
|
918
|
+
|
|
919
|
+
kernel void kernel_gelu_erf(
|
|
920
|
+
device const float * src0,
|
|
921
|
+
device float * dst,
|
|
922
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
923
|
+
device const float & x = src0[tpig];
|
|
924
|
+
|
|
925
|
+
dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
|
|
926
|
+
}
|
|
927
|
+
|
|
928
|
+
kernel void kernel_gelu_erf_4(
|
|
929
|
+
device const float4 * src0,
|
|
930
|
+
device float4 * dst,
|
|
931
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
932
|
+
device const float4 & x = src0[tpig];
|
|
933
|
+
|
|
934
|
+
dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
|
|
935
|
+
}
|
|
936
|
+
|
|
900
937
|
kernel void kernel_silu(
|
|
901
938
|
device const float * src0,
|
|
902
939
|
device float * dst,
|
|
@@ -2713,8 +2750,148 @@ kernel void kernel_rope_neox(
|
|
|
2713
2750
|
}
|
|
2714
2751
|
}
|
|
2715
2752
|
|
|
2753
|
+
template<typename T>
|
|
2754
|
+
kernel void kernel_rope_multi(
|
|
2755
|
+
constant ggml_metal_kargs_rope & args,
|
|
2756
|
+
device const char * src0,
|
|
2757
|
+
device const char * src1,
|
|
2758
|
+
device const char * src2,
|
|
2759
|
+
device char * dst,
|
|
2760
|
+
ushort tiitg[[thread_index_in_threadgroup]],
|
|
2761
|
+
ushort3 tptg [[threads_per_threadgroup]],
|
|
2762
|
+
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
|
2763
|
+
const int i3 = tgpig[2];
|
|
2764
|
+
const int i2 = tgpig[1];
|
|
2765
|
+
const int i1 = tgpig[0];
|
|
2766
|
+
|
|
2767
|
+
float corr_dims[2];
|
|
2768
|
+
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
|
2769
|
+
|
|
2770
|
+
device const int32_t * pos = (device const int32_t *) src1;
|
|
2771
|
+
|
|
2772
|
+
const float inv_ndims = -1.f/args.n_dims;
|
|
2773
|
+
|
|
2774
|
+
float cos_theta;
|
|
2775
|
+
float sin_theta;
|
|
2776
|
+
|
|
2777
|
+
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
|
2778
|
+
if (i0 < args.n_dims) {
|
|
2779
|
+
const int ic = i0/2;
|
|
2780
|
+
|
|
2781
|
+
// mrope theta calculations
|
|
2782
|
+
// note: the rest is the same as kernel_rope_neox
|
|
2783
|
+
const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
|
|
2784
|
+
const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1
|
|
2785
|
+
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
|
|
2786
|
+
const int sector = ic % sect_dims;
|
|
2787
|
+
|
|
2788
|
+
float theta_base;
|
|
2789
|
+
if (sector < args.sect_0) {
|
|
2790
|
+
theta_base = (float) pos[i2];
|
|
2791
|
+
} else if (sector < sec_w01) {
|
|
2792
|
+
theta_base = (float) pos[i2 + args.ne02];
|
|
2793
|
+
} else if (sector < sec_w012) {
|
|
2794
|
+
theta_base = (float) pos[i2 + args.ne02 * 2];
|
|
2795
|
+
} else {
|
|
2796
|
+
theta_base = (float) pos[i2 + args.ne02 * 3];
|
|
2797
|
+
}
|
|
2798
|
+
// end of mrope
|
|
2799
|
+
|
|
2800
|
+
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
|
|
2801
|
+
|
|
2802
|
+
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
|
2803
|
+
|
|
2804
|
+
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
|
2805
|
+
|
|
2806
|
+
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
|
|
2807
|
+
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
|
|
2808
|
+
|
|
2809
|
+
const float x0 = src[0];
|
|
2810
|
+
const float x1 = src[args.n_dims/2];
|
|
2811
|
+
|
|
2812
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
2813
|
+
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
|
|
2814
|
+
} else {
|
|
2815
|
+
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
|
2816
|
+
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
|
2817
|
+
|
|
2818
|
+
dst_data[0] = src[0];
|
|
2819
|
+
dst_data[1] = src[1];
|
|
2820
|
+
}
|
|
2821
|
+
}
|
|
2822
|
+
}
|
|
2823
|
+
|
|
2824
|
+
template<typename T>
|
|
2825
|
+
kernel void kernel_rope_vision(
|
|
2826
|
+
constant ggml_metal_kargs_rope & args,
|
|
2827
|
+
device const char * src0,
|
|
2828
|
+
device const char * src1,
|
|
2829
|
+
device const char * src2,
|
|
2830
|
+
device char * dst,
|
|
2831
|
+
ushort tiitg[[thread_index_in_threadgroup]],
|
|
2832
|
+
ushort3 tptg [[threads_per_threadgroup]],
|
|
2833
|
+
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
|
2834
|
+
const int i3 = tgpig[2];
|
|
2835
|
+
const int i2 = tgpig[1];
|
|
2836
|
+
const int i1 = tgpig[0];
|
|
2837
|
+
|
|
2838
|
+
float corr_dims[2];
|
|
2839
|
+
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
|
2840
|
+
|
|
2841
|
+
device const int32_t * pos = (device const int32_t *) src1;
|
|
2842
|
+
|
|
2843
|
+
const float inv_ndims = -1.f/args.n_dims;
|
|
2844
|
+
|
|
2845
|
+
float cos_theta;
|
|
2846
|
+
float sin_theta;
|
|
2847
|
+
|
|
2848
|
+
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
|
2849
|
+
if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
|
|
2850
|
+
const int ic = i0/2;
|
|
2851
|
+
|
|
2852
|
+
// mrope theta calculations (only support 2 dimensions)
|
|
2853
|
+
const int sect_dims = args.sect_0 + args.sect_1;
|
|
2854
|
+
const int sector = ic % sect_dims;
|
|
2855
|
+
|
|
2856
|
+
float p;
|
|
2857
|
+
float theta_base;
|
|
2858
|
+
if (sector < args.sect_1) {
|
|
2859
|
+
p = (float) sector;
|
|
2860
|
+
theta_base = (float) pos[i2];
|
|
2861
|
+
} else {
|
|
2862
|
+
p = (float) sector - args.sect_0;
|
|
2863
|
+
theta_base = (float) pos[i2 + args.ne02];
|
|
2864
|
+
}
|
|
2865
|
+
|
|
2866
|
+
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
|
|
2867
|
+
// end of mrope
|
|
2868
|
+
|
|
2869
|
+
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
|
2870
|
+
|
|
2871
|
+
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
|
2872
|
+
|
|
2873
|
+
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
|
|
2874
|
+
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
|
|
2875
|
+
|
|
2876
|
+
const float x0 = src[0];
|
|
2877
|
+
const float x1 = src[args.n_dims]; // different from kernel_rope_multi
|
|
2878
|
+
|
|
2879
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
2880
|
+
dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
|
|
2881
|
+
} else {
|
|
2882
|
+
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
|
2883
|
+
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
|
2884
|
+
|
|
2885
|
+
dst_data[0] = src[0];
|
|
2886
|
+
dst_data[1] = src[1];
|
|
2887
|
+
}
|
|
2888
|
+
}
|
|
2889
|
+
}
|
|
2890
|
+
|
|
2716
2891
|
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
|
|
2717
2892
|
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
|
|
2893
|
+
typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
|
|
2894
|
+
typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
|
|
2718
2895
|
|
|
2719
2896
|
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
|
|
2720
2897
|
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
|
|
@@ -2722,6 +2899,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_
|
|
|
2722
2899
|
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
|
|
2723
2900
|
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
|
|
2724
2901
|
|
|
2902
|
+
template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
|
|
2903
|
+
template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
|
|
2904
|
+
|
|
2905
|
+
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
|
|
2906
|
+
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
|
|
2907
|
+
|
|
2725
2908
|
typedef void (im2col_t)(
|
|
2726
2909
|
device const float * x,
|
|
2727
2910
|
device char * dst,
|
|
@@ -3109,7 +3292,7 @@ template<
|
|
|
3109
3292
|
typename kd4x4_t, // key type in device memory
|
|
3110
3293
|
short nl_k,
|
|
3111
3294
|
void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
|
|
3112
|
-
typename vd4x4_t, //
|
|
3295
|
+
typename vd4x4_t, // value type in device memory
|
|
3113
3296
|
short nl_v,
|
|
3114
3297
|
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
|
|
3115
3298
|
short DK, // K head size
|
|
@@ -3630,7 +3813,7 @@ template<
|
|
|
3630
3813
|
typename kd4_t, // key type in device memory
|
|
3631
3814
|
short nl_k,
|
|
3632
3815
|
void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
|
|
3633
|
-
typename vd4_t, //
|
|
3816
|
+
typename vd4_t, // value type in device memory
|
|
3634
3817
|
short nl_v,
|
|
3635
3818
|
void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
|
|
3636
3819
|
short DK, // K head size
|
|
@@ -3741,6 +3924,11 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
3741
3924
|
sm[tiisg] = pm[ic + tiisg];
|
|
3742
3925
|
}
|
|
3743
3926
|
|
|
3927
|
+
// skip -INF blocks
|
|
3928
|
+
if (simd_max(sm[tiisg]) == -INFINITY) {
|
|
3929
|
+
continue;
|
|
3930
|
+
}
|
|
3931
|
+
|
|
3744
3932
|
// Q*K^T
|
|
3745
3933
|
{
|
|
3746
3934
|
// each simdgroup processes 1 query and NE (NW/NL) head elements
|
|
@@ -3973,6 +4161,16 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
3973
4161
|
|
|
3974
4162
|
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
|
|
3975
4163
|
|
|
4164
|
+
// Vec flash-attention instantiations whose host names end in _h64. They differ
// from the _h96 lines visible further down in the last three template arguments
// (64, 64, 8 here versus 96, 96, 4 there).
// NOTE(review): presumably K head size, V head size and NE - confirm against the
// kernel_flash_attn_ext_vec template parameter list.
template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 8>;
|
|
4165
|
+
// The bf16 variant is only instantiated when GGML_METAL_USE_BF16 is defined.
#if defined(GGML_METAL_USE_BF16)
|
|
4166
|
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 8>;
|
|
4167
|
+
#endif
|
|
4168
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 8>;
|
|
4169
|
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 8>;
|
|
4170
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 8>;
|
|
4171
|
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 8>;
|
|
4172
|
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 8>;
|
|
4173
|
+
|
|
3976
4174
|
template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
|
|
3977
4175
|
#if defined(GGML_METAL_USE_BF16)
|
|
3978
4176
|
template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
|
|
@@ -27,12 +27,15 @@ if (MUSAToolkit_FOUND)
|
|
|
27
27
|
|
|
28
28
|
file(GLOB GGML_HEADERS_MUSA "../ggml-cuda/*.cuh")
|
|
29
29
|
list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")
|
|
30
|
+
list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh")
|
|
30
31
|
|
|
31
32
|
file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
|
|
32
33
|
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
|
|
33
34
|
list(APPEND GGML_SOURCES_MUSA ${SRCS})
|
|
34
35
|
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
|
|
35
36
|
list(APPEND GGML_SOURCES_MUSA ${SRCS})
|
|
37
|
+
file(GLOB SRCS "../ggml-musa/*.cu")
|
|
38
|
+
list(APPEND GGML_SOURCES_MUSA ${SRCS})
|
|
36
39
|
|
|
37
40
|
if (GGML_CUDA_FA_ALL_QUANTS)
|
|
38
41
|
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
|
|
@@ -62,7 +65,9 @@ if (MUSAToolkit_FOUND)
|
|
|
62
65
|
)
|
|
63
66
|
|
|
64
67
|
# TODO: do not use CUDA definitions for MUSA
|
|
65
|
-
|
|
68
|
+
if (NOT GGML_BACKEND_DL)
|
|
69
|
+
target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
|
|
70
|
+
endif()
|
|
66
71
|
|
|
67
72
|
add_compile_definitions(GGML_USE_MUSA)
|
|
68
73
|
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
|
|
@@ -92,9 +97,10 @@ if (MUSAToolkit_FOUND)
|
|
|
92
97
|
endif()
|
|
93
98
|
|
|
94
99
|
if (GGML_STATIC)
|
|
100
|
+
# TODO: mudnn has not provided static libraries yet
|
|
95
101
|
target_link_libraries(ggml-musa PRIVATE MUSA::musart_static MUSA::mublas_static)
|
|
96
102
|
else()
|
|
97
|
-
target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas)
|
|
103
|
+
target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas mudnn)
|
|
98
104
|
endif()
|
|
99
105
|
|
|
100
106
|
if (GGML_CUDA_NO_VMM)
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
#include <mutex>
|
|
2
|
+
#include <mudnn.h>
|
|
3
|
+
|
|
4
|
+
#include "mudnn.cuh"
|
|
5
|
+
|
|
6
|
+
namespace mudnn = musa::dnn;
|
|
7
|
+
|
|
8
|
+
// Returns a human-readable error string for mudnn::Status.
// Each status listed in the switch maps to its own string literal; any other
// value yields "Unknown mudnn status". This function is passed to
// CUDA_CHECK_GEN as the status-to-text callback by the MUDNN_CHECK macro.
|
|
9
|
+
const char* mudnnGetErrorString(mudnn::Status err) {
|
|
10
|
+
switch (err) {
|
|
11
|
+
case mudnn::Status::SUCCESS:
|
|
12
|
+
return "Success";
|
|
13
|
+
case mudnn::Status::INVALID_PARAMETER:
|
|
14
|
+
return "Invalid parameter";
|
|
15
|
+
case mudnn::Status::NOT_INITIALIZED:
|
|
16
|
+
return "Not initialized";
|
|
17
|
+
case mudnn::Status::ALLOC_FAILED:
|
|
18
|
+
return "Allocation failed";
|
|
19
|
+
case mudnn::Status::NOT_SUPPORTED:
|
|
20
|
+
return "Not supported";
|
|
21
|
+
case mudnn::Status::INTERNAL_ERROR:
|
|
22
|
+
return "Internal error";
|
|
23
|
+
case mudnn::Status::ARCH_MISMATCH:
|
|
24
|
+
return "Architecture mismatch";
|
|
25
|
+
case mudnn::Status::EXECUTION_FAILED:
|
|
26
|
+
return "Execution failed";
|
|
27
|
+
// Any status value not listed above.
default:
|
|
28
|
+
return "Unknown mudnn status";
|
|
29
|
+
}
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
// Error checking macro for MUDNN calls: forwards `err` to CUDA_CHECK_GEN with
// mudnn::Status::SUCCESS as the expected value and mudnnGetErrorString as the
// status-to-text function.
// NOTE(review): CUDA_CHECK_GEN is not defined in this file (presumably it comes
// from ggml-cuda/common.cuh via mudnn.cuh); see it for what happens on a mismatch.
|
|
33
|
+
#define MUDNN_CHECK(err) CUDA_CHECK_GEN(err, mudnn::Status::SUCCESS, mudnnGetErrorString)
|
|
34
|
+
|
|
35
|
+
namespace {
|
|
36
|
+
// Cache of mudnn::Handle objects keyed by device id. The map owns the handles;
// get_cached_handle() takes handle_cache_mutex around every lookup/insertion.
// NOTE(review): <unordered_map> and <memory> are not included by this file -
// presumably they arrive transitively through mudnn.cuh; confirm.
|
|
37
|
+
std::unordered_map<int, std::unique_ptr<mudnn::Handle>> handle_cache;
|
|
38
|
+
std::mutex handle_cache_mutex;
|
|
39
|
+
|
|
40
|
+
// Returns the mudnn::Handle cached for `device_id`, constructing and storing
// one on first use. The returned pointer is non-owning (the map keeps the
// unique_ptr). The mutex is held only for the duration of this call, not
// while the caller goes on to use the handle.
mudnn::Handle* get_cached_handle(int device_id) {
|
|
41
|
+
std::lock_guard<std::mutex> lock(handle_cache_mutex);
|
|
42
|
+
auto it = handle_cache.find(device_id);
|
|
43
|
+
if (it != handle_cache.end()) {
|
|
44
|
+
return it->second.get();
|
|
45
|
+
}
|
|
46
|
+
auto handle = std::make_unique<mudnn::Handle>(device_id);
|
|
47
|
+
// Take the raw pointer before `handle` is moved into the map just below.
mudnn::Handle* handle_ptr = handle.get();
|
|
48
|
+
handle_cache[device_id] = std::move(handle);
|
|
49
|
+
return handle_ptr;
|
|
50
|
+
}
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
// Extracts dimensions and strides from a ggml_tensor.
// Resizes `dims` and `strides` to ggml_n_dims(tensor) entries and fills them
// from tensor->ne[i] and tensor->nb[i] / ggml_element_size(tensor) respectively
// (any previous contents of the two vectors are overwritten).
// Returns the number of dimensions written.
|
|
54
|
+
int get_ggml_dims_and_strides(const ggml_tensor* tensor,
|
|
55
|
+
std::vector<int64_t>& dims,
|
|
56
|
+
std::vector<int64_t>& strides) {
|
|
57
|
+
const int ndims = ggml_n_dims(tensor);
|
|
58
|
+
const size_t element_size = ggml_element_size(tensor);
|
|
59
|
+
|
|
60
|
+
dims.resize(ndims);
|
|
61
|
+
strides.resize(ndims);
|
|
62
|
+
|
|
63
|
+
for (int i = 0; i < ndims; ++i) {
|
|
64
|
+
dims[i] = tensor->ne[i];
|
|
65
|
+
// nb[i] divided by the element size - presumably turns a byte stride into an
// element stride; confirm against ggml's definition of nb.
strides[i] = tensor->nb[i] / static_cast<int64_t>(element_size);
|
|
66
|
+
}
|
|
67
|
+
return ndims;
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
// Converts ggml_type to mudnn::Tensor::Type.
// Only GGML_TYPE_F32 (-> FLOAT) and GGML_TYPE_F16 (-> HALF) are mapped; every
// other type is reported through MUDNN_CHECK as mudnn::Status::NOT_SUPPORTED.
|
|
71
|
+
mudnn::Tensor::Type ggml_type_to_mudnn_type(ggml_type type) {
|
|
72
|
+
switch (type) {
|
|
73
|
+
case GGML_TYPE_F32:
|
|
74
|
+
return mudnn::Tensor::Type::FLOAT;
|
|
75
|
+
case GGML_TYPE_F16:
|
|
76
|
+
return mudnn::Tensor::Type::HALF;
|
|
77
|
+
|
|
78
|
+
// TODO: Add support for other types
|
|
79
|
+
|
|
80
|
+
default:
|
|
81
|
+
// Unsupported type: raise NOT_SUPPORTED through the checking macro.
MUDNN_CHECK(mudnn::Status::NOT_SUPPORTED);
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
return mudnn::Tensor::Type::FLOAT; // Default fallback, reached only if the NOT_SUPPORTED check above returns
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
// Asynchronous memory copy using mudnn::Unary::IDENTITY.
// Describes `dst` and `src` as mudnn tensors, then runs an IDENTITY unary op
// from src to dst on ctx.stream() using the handle cached for ctx.device.
// Every mudnn call is wrapped in MUDNN_CHECK; the only value this function
// itself returns is musaSuccess.
|
|
88
|
+
musaError_t mudnnMemcpyAsync(ggml_backend_cuda_context& ctx, const ggml_tensor* dst, const ggml_tensor* src) {
|
|
89
|
+
mudnn::Tensor tensor_dst, tensor_src;
|
|
90
|
+
|
|
91
|
+
MUDNN_CHECK(tensor_dst.SetType(ggml_type_to_mudnn_type(dst->type)));
|
|
92
|
+
MUDNN_CHECK(tensor_src.SetType(ggml_type_to_mudnn_type(src->type)));
|
|
93
|
+
|
|
94
|
+
std::vector<int64_t> dims, strides;
|
|
95
|
+
// Shape and strides are read from `src` only and applied to both descriptors.
// NOTE(review): this presumes dst shares src's layout - confirm at the call sites.
const int ndims = get_ggml_dims_and_strides(src, dims, strides);
|
|
96
|
+
|
|
97
|
+
MUDNN_CHECK(tensor_dst.SetNdInfo(ndims, dims.data(), strides.data()));
|
|
98
|
+
MUDNN_CHECK(tensor_src.SetNdInfo(ndims, dims.data(), strides.data()));
|
|
99
|
+
MUDNN_CHECK(tensor_dst.SetAddr(dst->data));
|
|
100
|
+
MUDNN_CHECK(tensor_src.SetAddr(src->data));
|
|
101
|
+
|
|
102
|
+
mudnn::Unary op;
|
|
103
|
+
MUDNN_CHECK(op.SetMode(mudnn::Unary::Mode::IDENTITY));
|
|
104
|
+
// NOTE(review): alpha and beta are both set to 0 - presumably they are not
// used by IDENTITY mode; confirm against the muDNN documentation.
MUDNN_CHECK(op.SetAlpha(0.0f));
|
|
105
|
+
MUDNN_CHECK(op.SetBeta(0.0f));
|
|
106
|
+
|
|
107
|
+
mudnn::Handle* handle = get_cached_handle(ctx.device);
|
|
108
|
+
// The handle is cached per device, so set this context's stream on it before running.
MUDNN_CHECK(handle->SetStream(ctx.stream()));
|
|
109
|
+
MUDNN_CHECK(op.Run(*handle, tensor_dst, tensor_src));
|
|
110
|
+
|
|
111
|
+
return musaSuccess;
|
|
112
|
+
}
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include "../include/ggml.h"
|
|
4
|
+
#include "../ggml-cuda/common.cuh"
|
|
5
|
+
|
|
6
|
+
// Asynchronously copies data from src tensor to dst tensor using the provided context.
// Note the argument order: destination first, then source.
|
|
7
|
+
// Returns a musaError_t indicating success or failure.
// NOTE(review): the definition added in this diff only ever returns musaSuccess;
// mudnn failures are routed through its MUDNN_CHECK macro instead - confirm that
// callers do not rely on receiving a non-success value.
|
|
8
|
+
musaError_t mudnnMemcpyAsync(
|
|
9
|
+
ggml_backend_cuda_context &ctx,
|
|
10
|
+
const ggml_tensor *dst,
|
|
11
|
+
const ggml_tensor *src
|
|
12
|
+
);
|
|
@@ -55,14 +55,17 @@ endfunction()
|
|
|
55
55
|
|
|
56
56
|
set(GGML_OPENCL_KERNELS
|
|
57
57
|
add
|
|
58
|
+
argsort
|
|
58
59
|
clamp
|
|
59
60
|
cpy
|
|
60
61
|
cvt
|
|
61
62
|
diag_mask_inf
|
|
63
|
+
div
|
|
62
64
|
gelu
|
|
63
65
|
gemv_noshuffle_general
|
|
64
66
|
gemv_noshuffle
|
|
65
67
|
get_rows
|
|
68
|
+
group_norm
|
|
66
69
|
im2col_f32
|
|
67
70
|
im2col_f16
|
|
68
71
|
mul_mat_Ab_Bi_8x4
|
|
@@ -83,11 +86,14 @@ set(GGML_OPENCL_KERNELS
|
|
|
83
86
|
rms_norm
|
|
84
87
|
rope
|
|
85
88
|
scale
|
|
89
|
+
sigmoid
|
|
86
90
|
silu
|
|
87
91
|
softmax_4_f32
|
|
88
92
|
softmax_4_f16
|
|
89
93
|
softmax_f32
|
|
90
94
|
softmax_f16
|
|
95
|
+
sub
|
|
96
|
+
sum_rows
|
|
91
97
|
transpose
|
|
92
98
|
)
|
|
93
99
|
|