@novastera-oss/llamarn 0.2.9 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/proguard-rules.pro +12 -0
- package/android/src/main/cpp/include/llama.h +15 -47
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +0 -1
- package/cpp/llama.cpp/CMakePresets.json +11 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -0
- package/cpp/llama.cpp/README.md +8 -8
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +62 -1
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +22 -6
- package/cpp/llama.cpp/common/common.h +22 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
- package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
- package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
- package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
- package/cpp/llama.cpp/include/llama.h +15 -47
- package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
- package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
- package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
- package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
- package/cpp/llama.cpp/src/llama-arch.h +23 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
- package/cpp/llama.cpp/src/llama-batch.h +31 -18
- package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
- package/cpp/llama.cpp/src/llama-chat.h +3 -0
- package/cpp/llama.cpp/src/llama-context.cpp +180 -106
- package/cpp/llama.cpp/src/llama-context.h +26 -16
- package/cpp/llama.cpp/src/llama-cparams.h +3 -2
- package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
- package/cpp/llama.cpp/src/llama-graph.h +184 -122
- package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
- package/cpp/llama.cpp/src/llama-hparams.h +13 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
- package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +3 -0
- package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
- package/cpp/llama.cpp/src/llama-model.h +21 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
- package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
- package/cpp/llama.cpp/src/llama-vocab.h +43 -0
- package/cpp/llama.cpp/src/unicode.cpp +207 -0
- package/cpp/llama.cpp/src/unicode.h +2 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +22 -4
- package/ios/include/llama.h +15 -47
- package/ios/libs/llama.xcframework/Info.plist +13 -13
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +4 -4
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -43,6 +43,7 @@
|
|
|
43
43
|
#include "ggml-cuda/upscale.cuh"
|
|
44
44
|
#include "ggml-cuda/wkv.cuh"
|
|
45
45
|
#include "ggml-cuda/gla.cuh"
|
|
46
|
+
#include "ggml-cuda/set-rows.cuh"
|
|
46
47
|
#include "ggml.h"
|
|
47
48
|
|
|
48
49
|
#include <algorithm>
|
|
@@ -54,6 +55,7 @@
|
|
|
54
55
|
#include <cstddef>
|
|
55
56
|
#include <cstdint>
|
|
56
57
|
#include <float.h>
|
|
58
|
+
#include <initializer_list>
|
|
57
59
|
#include <limits>
|
|
58
60
|
#include <map>
|
|
59
61
|
#include <memory>
|
|
@@ -1749,7 +1751,7 @@ static void ggml_cuda_op_mul_mat(
|
|
|
1749
1751
|
}
|
|
1750
1752
|
|
|
1751
1753
|
static __global__ void k_compute_batched_ptrs(
|
|
1752
|
-
const
|
|
1754
|
+
const void * src0_as_f16, const void * src1_as_f16, char * dst,
|
|
1753
1755
|
const void ** ptrs_src, void ** ptrs_dst,
|
|
1754
1756
|
int64_t ne12, int64_t ne13,
|
|
1755
1757
|
int64_t ne23,
|
|
@@ -1772,83 +1774,131 @@ static __global__ void k_compute_batched_ptrs(
|
|
|
1772
1774
|
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
|
|
1773
1775
|
}
|
|
1774
1776
|
|
|
1775
|
-
|
|
1777
|
+
// Type traits for mapping ggml types to CUDA/cuBLAS types
|
|
1778
|
+
template<ggml_type T>
|
|
1779
|
+
struct batched_mul_mat_traits;
|
|
1780
|
+
|
|
1781
|
+
template<>
|
|
1782
|
+
struct batched_mul_mat_traits<GGML_TYPE_F32> {
|
|
1783
|
+
using cuda_type = float;
|
|
1784
|
+
static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
|
|
1785
|
+
static inline const cudaDataType_t data_type = CUDA_R_32F;
|
|
1786
|
+
static inline const ggml_type ggml_type_val = GGML_TYPE_F32;
|
|
1787
|
+
static inline const float alpha = 1.0f;
|
|
1788
|
+
static inline const float beta = 0.0f;
|
|
1789
|
+
static inline const void* get_alpha() { static const float val = alpha; return &val; }
|
|
1790
|
+
static inline const void* get_beta() { static const float val = beta; return &val; }
|
|
1791
|
+
static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); }
|
|
1792
|
+
};
|
|
1793
|
+
|
|
1794
|
+
template<>
|
|
1795
|
+
struct batched_mul_mat_traits<GGML_TYPE_BF16> {
|
|
1796
|
+
using cuda_type = nv_bfloat16;
|
|
1797
|
+
static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
|
|
1798
|
+
static inline const cudaDataType_t data_type = CUDA_R_16BF;
|
|
1799
|
+
static inline const ggml_type ggml_type_val = GGML_TYPE_BF16;
|
|
1800
|
+
static inline const float alpha = 1.0f;
|
|
1801
|
+
static inline const float beta = 0.0f;
|
|
1802
|
+
static inline const void* get_alpha() { static const float val = alpha; return &val; }
|
|
1803
|
+
static inline const void* get_beta() { static const float val = beta; return &val; }
|
|
1804
|
+
static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); }
|
|
1805
|
+
};
|
|
1806
|
+
|
|
1807
|
+
template<>
|
|
1808
|
+
struct batched_mul_mat_traits<GGML_TYPE_F16> {
|
|
1809
|
+
using cuda_type = half;
|
|
1810
|
+
static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
|
|
1811
|
+
static inline const cudaDataType_t data_type = CUDA_R_16F;
|
|
1812
|
+
static inline const ggml_type ggml_type_val = GGML_TYPE_F16;
|
|
1813
|
+
static inline const half alpha = 1.0;
|
|
1814
|
+
static inline const half beta = 0.0;
|
|
1815
|
+
static inline const void* get_alpha() { static const half val = alpha; return &val; }
|
|
1816
|
+
static inline const void* get_beta() { static const half val = beta; return &val; }
|
|
1817
|
+
static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); }
|
|
1818
|
+
};
|
|
1819
|
+
|
|
1820
|
+
template<ggml_type src0_type>
|
|
1821
|
+
static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
|
1822
|
+
using traits = batched_mul_mat_traits<src0_type>;
|
|
1823
|
+
using cuda_t = typename traits::cuda_type;
|
|
1824
|
+
|
|
1776
1825
|
GGML_ASSERT(!ggml_is_transposed(src0));
|
|
1777
1826
|
GGML_ASSERT(!ggml_is_transposed(src1));
|
|
1778
|
-
|
|
1779
1827
|
GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
|
|
1780
|
-
GGML_ASSERT(src0->type ==
|
|
1828
|
+
GGML_ASSERT(src0->type == src0_type);
|
|
1829
|
+
GGML_ASSERT(ggml_is_contiguous(dst));
|
|
1781
1830
|
|
|
1782
1831
|
// Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
|
|
1783
1832
|
// As long as dst is contiguous this does not matter though.
|
|
1784
|
-
GGML_ASSERT(ggml_is_contiguous(dst));
|
|
1785
1833
|
|
|
1786
1834
|
GGML_TENSOR_BINARY_OP_LOCALS
|
|
1787
1835
|
|
|
1788
1836
|
const int64_t ne_dst = ggml_nelements(dst);
|
|
1789
|
-
|
|
1790
1837
|
cudaStream_t main_stream = ctx.stream();
|
|
1791
|
-
|
|
1792
1838
|
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
|
|
1793
1839
|
|
|
1794
|
-
const half * src0_f16 = (const half *) src0->data;
|
|
1795
1840
|
float * dst_ddf = (float *) dst->data;
|
|
1796
|
-
|
|
1797
|
-
const half * src1_f16 = (const half *) src1->data;
|
|
1798
1841
|
const size_t ts_src1 = ggml_type_size(src1->type);
|
|
1799
1842
|
GGML_ASSERT(nb10 == ts_src1);
|
|
1800
1843
|
int64_t s11 = nb11 / ts_src1;
|
|
1801
1844
|
int64_t s12 = nb12 / ts_src1;
|
|
1802
1845
|
int64_t s13 = nb13 / ts_src1;
|
|
1803
|
-
ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
|
|
1804
1846
|
|
|
1805
|
-
|
|
1806
|
-
|
|
1807
|
-
|
|
1808
|
-
|
|
1809
|
-
|
|
1810
|
-
GGML_ASSERT(to_fp16_cuda != nullptr);
|
|
1847
|
+
const cuda_t * src0_ptr = nullptr;
|
|
1848
|
+
const cuda_t * src1_ptr = nullptr;
|
|
1849
|
+
|
|
1850
|
+
ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
|
|
1851
|
+
ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
|
|
1811
1852
|
|
|
1812
|
-
|
|
1853
|
+
// Handle src0
|
|
1854
|
+
src0_ptr = (const cuda_t *) src0->data;
|
|
1813
1855
|
|
|
1814
|
-
|
|
1856
|
+
// Handle src1 - convert if necessary
|
|
1857
|
+
if (src1->type == src0_type) {
|
|
1858
|
+
src1_ptr = (const cuda_t *) src1->data;
|
|
1859
|
+
} else {
|
|
1860
|
+
// Convert src1 to target type using traits conversion functions
|
|
1861
|
+
const int64_t ne_src1 = ggml_nelements(src1);
|
|
1862
|
+
src1_alloc.alloc(ne_src1);
|
|
1863
|
+
|
|
1864
|
+
const auto convert_func = traits::get_nc_converter(src1->type);
|
|
1865
|
+
GGML_ASSERT(convert_func != nullptr);
|
|
1866
|
+
convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
|
|
1867
|
+
src1_ptr = src1_alloc.get();
|
|
1815
1868
|
s11 = ne10;
|
|
1816
1869
|
s12 = ne11*s11;
|
|
1817
1870
|
s13 = ne12*s12;
|
|
1818
1871
|
}
|
|
1819
1872
|
|
|
1820
|
-
|
|
1873
|
+
// Setup destination buffer
|
|
1874
|
+
ggml_cuda_pool_alloc<cuda_t> dst_temp(ctx.pool());
|
|
1821
1875
|
char * dst_t;
|
|
1822
|
-
|
|
1823
|
-
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
|
|
1824
|
-
cudaDataType_t cu_data_type = CUDA_R_16F;
|
|
1825
|
-
|
|
1826
|
-
// dst strides
|
|
1827
1876
|
size_t nbd2 = dst->nb[2];
|
|
1828
1877
|
size_t nbd3 = dst->nb[3];
|
|
1829
1878
|
|
|
1830
|
-
|
|
1831
|
-
|
|
1832
|
-
|
|
1879
|
+
cublasComputeType_t cu_compute_type = traits::compute_type;
|
|
1880
|
+
cudaDataType_t cu_data_type = traits::data_type;
|
|
1881
|
+
cudaDataType_t cu_data_type_a = traits::data_type;
|
|
1882
|
+
cudaDataType_t cu_data_type_b = traits::data_type;
|
|
1883
|
+
const void * alpha = traits::get_alpha();
|
|
1884
|
+
const void * beta = traits::get_beta();
|
|
1833
1885
|
const float alpha_f32 = 1.0f;
|
|
1834
|
-
const float beta_f32
|
|
1835
|
-
|
|
1836
|
-
const void * alpha = &alpha_f16;
|
|
1837
|
-
const void * beta = &beta_f16;
|
|
1886
|
+
const float beta_f32 = 0.0f;
|
|
1838
1887
|
|
|
1839
1888
|
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
|
|
1840
|
-
|
|
1841
|
-
|
|
1842
|
-
|
|
1843
|
-
|
|
1889
|
+
if constexpr (src0_type == GGML_TYPE_F32) {
|
|
1890
|
+
dst_t = (char *) dst_ddf; // Direct F32 output
|
|
1891
|
+
} else {
|
|
1892
|
+
dst_t = (char *) dst_temp.alloc(ne_dst);
|
|
1893
|
+
nbd2 /= sizeof(float) / sizeof(cuda_t);
|
|
1894
|
+
nbd3 /= sizeof(float) / sizeof(cuda_t);
|
|
1895
|
+
}
|
|
1844
1896
|
} else {
|
|
1845
1897
|
dst_t = (char *) dst_ddf;
|
|
1846
|
-
|
|
1847
1898
|
cu_compute_type = CUBLAS_COMPUTE_32F;
|
|
1848
|
-
cu_data_type
|
|
1849
|
-
|
|
1899
|
+
cu_data_type = CUDA_R_32F;
|
|
1850
1900
|
alpha = &alpha_f32;
|
|
1851
|
-
beta
|
|
1901
|
+
beta = &beta_f32;
|
|
1852
1902
|
}
|
|
1853
1903
|
|
|
1854
1904
|
int id = ggml_cuda_get_device();
|
|
@@ -1856,7 +1906,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
|
|
1856
1906
|
if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
|
|
1857
1907
|
cu_compute_type = CUBLAS_COMPUTE_32F;
|
|
1858
1908
|
alpha = &alpha_f32;
|
|
1859
|
-
beta
|
|
1909
|
+
beta = &beta_f32;
|
|
1860
1910
|
}
|
|
1861
1911
|
|
|
1862
1912
|
GGML_ASSERT(ne12 % ne02 == 0);
|
|
@@ -1866,35 +1916,15 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
|
|
1866
1916
|
const int64_t r2 = ne12/ne02;
|
|
1867
1917
|
const int64_t r3 = ne13/ne03;
|
|
1868
1918
|
|
|
1869
|
-
#if 0
|
|
1870
|
-
// use cublasGemmEx
|
|
1871
|
-
{
|
|
1872
|
-
for (int i13 = 0; i13 < ne13; ++i13) {
|
|
1873
|
-
for (int i12 = 0; i12 < ne12; ++i12) {
|
|
1874
|
-
int i03 = i13 / r3;
|
|
1875
|
-
int i02 = i12 / r2;
|
|
1876
|
-
|
|
1877
|
-
CUBLAS_CHECK(
|
|
1878
|
-
cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
|
|
1879
|
-
ne01, ne11, ne10,
|
|
1880
|
-
alpha, (const char *) src0_f16 + i03*nb03 + i02*nb02, CUDA_R_16F, nb01/sizeof(half),
|
|
1881
|
-
src1_f16 + i13*s13 + i12*s12, CUDA_R_16F, s11,
|
|
1882
|
-
beta, ( char *) dst_t + i13*nbd3 + i12*nbd2, cu_data_type, ne0,
|
|
1883
|
-
cu_compute_type,
|
|
1884
|
-
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
|
1885
|
-
}
|
|
1886
|
-
}
|
|
1887
|
-
}
|
|
1888
|
-
#else
|
|
1889
1919
|
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
|
|
1890
1920
|
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
|
|
1891
1921
|
// use cublasGemmStridedBatchedEx
|
|
1892
1922
|
CUBLAS_CHECK(
|
|
1893
1923
|
cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
|
|
1894
1924
|
ne01, ne11, ne10,
|
|
1895
|
-
alpha,
|
|
1896
|
-
|
|
1897
|
-
beta, dst_t, cu_data_type,
|
|
1925
|
+
alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
|
|
1926
|
+
src1_ptr, cu_data_type_b, s11, s12, // strideB
|
|
1927
|
+
beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
|
|
1898
1928
|
ne12*ne13,
|
|
1899
1929
|
cu_compute_type,
|
|
1900
1930
|
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
|
@@ -1905,34 +1935,55 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
|
|
1905
1935
|
ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
|
|
1906
1936
|
ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
|
|
1907
1937
|
|
|
1938
|
+
size_t src1_stride_size = sizeof(cuda_t);
|
|
1939
|
+
|
|
1908
1940
|
dim3 block_dims(ne13, ne12);
|
|
1909
1941
|
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
|
|
1910
|
-
|
|
1942
|
+
src0_ptr, src1_ptr, dst_t,
|
|
1911
1943
|
ptrs_src.get(), ptrs_dst.get(),
|
|
1912
1944
|
ne12, ne13,
|
|
1913
1945
|
ne23,
|
|
1914
1946
|
nb02, nb03,
|
|
1915
|
-
src1->type ==
|
|
1916
|
-
src1->type ==
|
|
1947
|
+
(src1->type == src0_type) ? nb12 : s12*src1_stride_size,
|
|
1948
|
+
(src1->type == src0_type) ? nb13 : s13*src1_stride_size,
|
|
1917
1949
|
nbd2, nbd3,
|
|
1918
1950
|
r2, r3);
|
|
1951
|
+
|
|
1919
1952
|
CUDA_CHECK(cudaGetLastError());
|
|
1920
1953
|
|
|
1921
1954
|
CUBLAS_CHECK(
|
|
1922
1955
|
cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
|
|
1923
1956
|
ne01, ne11, ne10,
|
|
1924
|
-
alpha, (const void **) (ptrs_src.get() + 0*ne23),
|
|
1925
|
-
(const void **) (ptrs_src.get() + 1*ne23),
|
|
1926
|
-
beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type,
|
|
1957
|
+
alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
|
|
1958
|
+
(const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
|
|
1959
|
+
beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
|
|
1927
1960
|
ne23,
|
|
1928
1961
|
cu_compute_type,
|
|
1929
1962
|
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
|
1930
1963
|
}
|
|
1931
|
-
#endif
|
|
1932
1964
|
|
|
1933
|
-
|
|
1934
|
-
|
|
1935
|
-
to_fp32_cuda
|
|
1965
|
+
// Convert output back to F32 if needed
|
|
1966
|
+
if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
|
|
1967
|
+
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
|
|
1968
|
+
to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);
|
|
1969
|
+
}
|
|
1970
|
+
}
|
|
1971
|
+
|
|
1972
|
+
static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
|
1973
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
|
|
1974
|
+
|
|
1975
|
+
switch (src0->type) {
|
|
1976
|
+
case GGML_TYPE_F32:
|
|
1977
|
+
ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
|
|
1978
|
+
break;
|
|
1979
|
+
case GGML_TYPE_BF16:
|
|
1980
|
+
ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
|
|
1981
|
+
break;
|
|
1982
|
+
case GGML_TYPE_F16:
|
|
1983
|
+
ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
|
|
1984
|
+
break;
|
|
1985
|
+
default:
|
|
1986
|
+
GGML_ABORT("Unsupported type");
|
|
1936
1987
|
}
|
|
1937
1988
|
}
|
|
1938
1989
|
|
|
@@ -1984,6 +2035,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
|
|
1984
2035
|
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
|
|
1985
2036
|
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
|
|
1986
2037
|
|
|
2038
|
+
//TODO update for generic tensor parallelism
|
|
2039
|
+
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
|
2040
|
+
bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
|
|
2041
|
+
bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
|
|
2042
|
+
bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
|
|
2043
|
+
|
|
1987
2044
|
if (!split && use_mul_mat_vec) {
|
|
1988
2045
|
// the custom F16 vector kernel can be used over batched cuBLAS GEMM
|
|
1989
2046
|
// but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
|
|
@@ -1992,8 +2049,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
|
|
1992
2049
|
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
|
|
1993
2050
|
} else if (!split && use_mul_mat_q) {
|
|
1994
2051
|
ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
|
|
1995
|
-
} else if (!split &&
|
|
1996
|
-
|
|
2052
|
+
} else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
|
|
2053
|
+
&& !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
|
|
1997
2054
|
// general KQ + KQV multi-batch without FlashAttention
|
|
1998
2055
|
ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
|
|
1999
2056
|
} else if (use_mul_mat_vec) {
|
|
@@ -2175,6 +2232,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|
|
2175
2232
|
case GGML_OP_GET_ROWS_BACK:
|
|
2176
2233
|
ggml_cuda_op_get_rows_back(ctx, dst);
|
|
2177
2234
|
break;
|
|
2235
|
+
case GGML_OP_SET_ROWS:
|
|
2236
|
+
ggml_cuda_op_set_rows(ctx, dst);
|
|
2237
|
+
break;
|
|
2178
2238
|
case GGML_OP_DUP:
|
|
2179
2239
|
ggml_cuda_dup(ctx, dst);
|
|
2180
2240
|
break;
|
|
@@ -2244,6 +2304,30 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|
|
2244
2304
|
case GGML_UNARY_OP_EXP:
|
|
2245
2305
|
ggml_cuda_op_exp(ctx, dst);
|
|
2246
2306
|
break;
|
|
2307
|
+
case GGML_UNARY_OP_ELU:
|
|
2308
|
+
ggml_cuda_op_elu(ctx, dst);
|
|
2309
|
+
break;
|
|
2310
|
+
default:
|
|
2311
|
+
return false;
|
|
2312
|
+
}
|
|
2313
|
+
break;
|
|
2314
|
+
case GGML_OP_GLU:
|
|
2315
|
+
switch (ggml_get_glu_op(dst)) {
|
|
2316
|
+
case GGML_GLU_OP_REGLU:
|
|
2317
|
+
ggml_cuda_op_reglu(ctx, dst);
|
|
2318
|
+
break;
|
|
2319
|
+
case GGML_GLU_OP_GEGLU:
|
|
2320
|
+
ggml_cuda_op_geglu(ctx, dst);
|
|
2321
|
+
break;
|
|
2322
|
+
case GGML_GLU_OP_SWIGLU:
|
|
2323
|
+
ggml_cuda_op_swiglu(ctx, dst);
|
|
2324
|
+
break;
|
|
2325
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
2326
|
+
ggml_cuda_op_geglu_erf(ctx, dst);
|
|
2327
|
+
break;
|
|
2328
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
2329
|
+
ggml_cuda_op_geglu_quick(ctx, dst);
|
|
2330
|
+
break;
|
|
2247
2331
|
default:
|
|
2248
2332
|
return false;
|
|
2249
2333
|
}
|
|
@@ -2507,6 +2591,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
|
|
|
2507
2591
|
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
|
|
2508
2592
|
cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
|
|
2509
2593
|
|
|
2594
|
+
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
|
|
2595
|
+
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
|
|
2596
|
+
|
|
2510
2597
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
2511
2598
|
ggml_tensor * node = cgraph->nodes[i];
|
|
2512
2599
|
|
|
@@ -2528,9 +2615,12 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
|
|
|
2528
2615
|
#endif
|
|
2529
2616
|
}
|
|
2530
2617
|
|
|
2531
|
-
if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
|
|
2532
|
-
// disable CUDA graphs for batch size > 1 for now
|
|
2533
|
-
//
|
|
2618
|
+
if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true)) {
|
|
2619
|
+
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
|
|
2620
|
+
// by means of matching node names. See
|
|
2621
|
+
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
|
|
2622
|
+
// https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
|
|
2623
|
+
// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
|
|
2534
2624
|
use_cuda_graph = false;
|
|
2535
2625
|
#ifndef NDEBUG
|
|
2536
2626
|
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
|
|
@@ -2676,6 +2766,39 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
|
|
|
2676
2766
|
}
|
|
2677
2767
|
#endif
|
|
2678
2768
|
|
|
2769
|
+
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
|
|
2770
|
+
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
|
|
2771
|
+
return false;
|
|
2772
|
+
}
|
|
2773
|
+
|
|
2774
|
+
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
|
|
2775
|
+
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
|
|
2776
|
+
const ggml_tensor *mul = cgraph->nodes[node_idx+1];
|
|
2777
|
+
|
|
2778
|
+
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
|
|
2779
|
+
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
|
|
2780
|
+
|
|
2781
|
+
//rms norm only supports F32
|
|
2782
|
+
if (mul->src[0]->type != GGML_TYPE_F32 ||
|
|
2783
|
+
mul->src[1]->type != GGML_TYPE_F32 ||
|
|
2784
|
+
mul->type != GGML_TYPE_F32) {
|
|
2785
|
+
return false;
|
|
2786
|
+
}
|
|
2787
|
+
|
|
2788
|
+
//if rms norm is the B operand, then we don't handle broadcast
|
|
2789
|
+
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
|
|
2790
|
+
return false;
|
|
2791
|
+
}
|
|
2792
|
+
|
|
2793
|
+
//rms_norm kernel assumes contigous rows
|
|
2794
|
+
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
|
|
2795
|
+
return false;
|
|
2796
|
+
}
|
|
2797
|
+
}
|
|
2798
|
+
|
|
2799
|
+
return true;
|
|
2800
|
+
}
|
|
2801
|
+
|
|
2679
2802
|
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
|
|
2680
2803
|
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
|
|
2681
2804
|
// flag used to determine whether it is an integrated_gpu
|
|
@@ -2685,6 +2808,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
|
|
2685
2808
|
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
|
|
2686
2809
|
// With the use of CUDA graphs, the execution will be performed by the graph launch.
|
|
2687
2810
|
if (!use_cuda_graph || cuda_graph_update_required) {
|
|
2811
|
+
|
|
2688
2812
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
2689
2813
|
ggml_tensor * node = cgraph->nodes[i];
|
|
2690
2814
|
|
|
@@ -2692,6 +2816,12 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
|
|
2692
2816
|
continue;
|
|
2693
2817
|
}
|
|
2694
2818
|
|
|
2819
|
+
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
|
|
2820
|
+
if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
|
2821
|
+
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
|
|
2822
|
+
i++;
|
|
2823
|
+
continue;
|
|
2824
|
+
}
|
|
2695
2825
|
#ifndef NDEBUG
|
|
2696
2826
|
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
|
|
2697
2827
|
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
|
@@ -3036,11 +3166,24 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
3036
3166
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
3037
3167
|
case GGML_UNARY_OP_TANH:
|
|
3038
3168
|
case GGML_UNARY_OP_EXP:
|
|
3169
|
+
case GGML_UNARY_OP_ELU:
|
|
3039
3170
|
return ggml_is_contiguous(op->src[0]);
|
|
3040
3171
|
default:
|
|
3041
3172
|
return false;
|
|
3042
3173
|
}
|
|
3043
3174
|
break;
|
|
3175
|
+
case GGML_OP_GLU:
|
|
3176
|
+
switch (ggml_get_glu_op(op)) {
|
|
3177
|
+
case GGML_GLU_OP_REGLU:
|
|
3178
|
+
case GGML_GLU_OP_GEGLU:
|
|
3179
|
+
case GGML_GLU_OP_SWIGLU:
|
|
3180
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
3181
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
3182
|
+
return ggml_is_contiguous_1(op->src[0]);
|
|
3183
|
+
default:
|
|
3184
|
+
return false;
|
|
3185
|
+
}
|
|
3186
|
+
break;
|
|
3044
3187
|
case GGML_OP_MUL_MAT:
|
|
3045
3188
|
case GGML_OP_MUL_MAT_ID:
|
|
3046
3189
|
{
|
|
@@ -3112,6 +3255,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
3112
3255
|
switch (op->src[0]->type) {
|
|
3113
3256
|
case GGML_TYPE_F16:
|
|
3114
3257
|
case GGML_TYPE_F32:
|
|
3258
|
+
case GGML_TYPE_BF16:
|
|
3259
|
+
case GGML_TYPE_I32:
|
|
3115
3260
|
case GGML_TYPE_Q4_0:
|
|
3116
3261
|
case GGML_TYPE_Q4_1:
|
|
3117
3262
|
case GGML_TYPE_Q5_0:
|
|
@@ -3126,17 +3271,21 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
3126
3271
|
{
|
|
3127
3272
|
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
|
|
3128
3273
|
} break;
|
|
3274
|
+
case GGML_OP_SET_ROWS:
|
|
3275
|
+
{
|
|
3276
|
+
return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
|
|
3277
|
+
op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 ||
|
|
3278
|
+
op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) &&
|
|
3279
|
+
op->src[0]->type == GGML_TYPE_F32 &&
|
|
3280
|
+
op->src[1]->type == GGML_TYPE_I64;
|
|
3281
|
+
} break;
|
|
3129
3282
|
case GGML_OP_CPY:
|
|
3130
3283
|
{
|
|
3131
3284
|
ggml_type src0_type = op->src[0]->type;
|
|
3132
3285
|
ggml_type src1_type = op->src[1]->type;
|
|
3133
|
-
if (src0_type == GGML_TYPE_F32
|
|
3134
|
-
|
|
3135
|
-
|
|
3136
|
-
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) {
|
|
3137
|
-
return true;
|
|
3138
|
-
}
|
|
3139
|
-
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
|
|
3286
|
+
if ((src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_F16) &&
|
|
3287
|
+
(src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F16)
|
|
3288
|
+
) {
|
|
3140
3289
|
return true;
|
|
3141
3290
|
}
|
|
3142
3291
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
|
|
@@ -3172,12 +3321,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
3172
3321
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
|
|
3173
3322
|
return true;
|
|
3174
3323
|
}
|
|
3175
|
-
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
|
|
3176
|
-
return true;
|
|
3177
|
-
}
|
|
3178
|
-
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
|
3179
|
-
return true;
|
|
3180
|
-
}
|
|
3181
3324
|
if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
|
|
3182
3325
|
return true;
|
|
3183
3326
|
}
|
|
@@ -3241,12 +3384,26 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
3241
3384
|
case GGML_OP_COS:
|
|
3242
3385
|
case GGML_OP_CLAMP:
|
|
3243
3386
|
case GGML_OP_LOG:
|
|
3244
|
-
case GGML_OP_SSM_SCAN:
|
|
3245
|
-
case GGML_OP_SSM_CONV:
|
|
3246
3387
|
return true;
|
|
3388
|
+
case GGML_OP_SSM_SCAN: {
|
|
3389
|
+
if (op->src[3]->ne[0] == 1) {
|
|
3390
|
+
// Mamba2
|
|
3391
|
+
// (kernel only supports (d_state == 128 || d_state == 256) && d_head % 16 == 0)
|
|
3392
|
+
return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % 16 == 0;
|
|
3393
|
+
} else {
|
|
3394
|
+
// Mamba
|
|
3395
|
+
// (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
|
|
3396
|
+
return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
|
|
3397
|
+
}
|
|
3398
|
+
}
|
|
3399
|
+
case GGML_OP_SSM_CONV: {
|
|
3400
|
+
// assumes d_inner % threads == 0
|
|
3401
|
+
return op->src[0]->ne[1] % 128 == 0;
|
|
3402
|
+
}
|
|
3247
3403
|
case GGML_OP_CONT:
|
|
3248
|
-
return
|
|
3404
|
+
return true;
|
|
3249
3405
|
case GGML_OP_DIAG_MASK_INF:
|
|
3406
|
+
return true;
|
|
3250
3407
|
case GGML_OP_SOFT_MAX:
|
|
3251
3408
|
return true;
|
|
3252
3409
|
case GGML_OP_SOFT_MAX_BACK: {
|
|
@@ -3271,7 +3428,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
3271
3428
|
case GGML_OP_GROUP_NORM:
|
|
3272
3429
|
return ggml_is_contiguous(op->src[0]);
|
|
3273
3430
|
case GGML_OP_UPSCALE:
|
|
3274
|
-
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
|
3275
3431
|
case GGML_OP_PAD:
|
|
3276
3432
|
case GGML_OP_ARANGE:
|
|
3277
3433
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
@@ -3295,9 +3451,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
3295
3451
|
if (op->src[0]->ne[0] == 192) {
|
|
3296
3452
|
return false;
|
|
3297
3453
|
}
|
|
3298
|
-
if (op->src[0]->ne[3] != 1) {
|
|
3299
|
-
return false;
|
|
3300
|
-
}
|
|
3301
3454
|
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
|
|
3302
3455
|
return false;
|
|
3303
3456
|
}
|
|
@@ -3310,6 +3463,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
3310
3463
|
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
|
|
3311
3464
|
return true;
|
|
3312
3465
|
}
|
|
3466
|
+
if (op->src[3] && op->src[3]->ne[2] != 1) {
|
|
3467
|
+
return false;
|
|
3468
|
+
}
|
|
3313
3469
|
return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
|
|
3314
3470
|
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
|
|
3315
3471
|
}
|