@novastera-oss/llamarn 0.2.9 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/proguard-rules.pro +12 -0
- package/android/src/main/cpp/include/llama.h +15 -47
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +0 -1
- package/cpp/llama.cpp/CMakePresets.json +11 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -0
- package/cpp/llama.cpp/README.md +8 -8
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +62 -1
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +22 -6
- package/cpp/llama.cpp/common/common.h +22 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
- package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
- package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
- package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
- package/cpp/llama.cpp/include/llama.h +15 -47
- package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
- package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
- package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
- package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
- package/cpp/llama.cpp/src/llama-arch.h +23 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
- package/cpp/llama.cpp/src/llama-batch.h +31 -18
- package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
- package/cpp/llama.cpp/src/llama-chat.h +3 -0
- package/cpp/llama.cpp/src/llama-context.cpp +180 -106
- package/cpp/llama.cpp/src/llama-context.h +26 -16
- package/cpp/llama.cpp/src/llama-cparams.h +3 -2
- package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
- package/cpp/llama.cpp/src/llama-graph.h +184 -122
- package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
- package/cpp/llama.cpp/src/llama-hparams.h +13 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
- package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +3 -0
- package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
- package/cpp/llama.cpp/src/llama-model.h +21 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
- package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
- package/cpp/llama.cpp/src/llama-vocab.h +43 -0
- package/cpp/llama.cpp/src/unicode.cpp +207 -0
- package/cpp/llama.cpp/src/unicode.h +2 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +22 -4
- package/ios/include/llama.h +15 -47
- package/ios/libs/llama.xcframework/Info.plist +13 -13
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +4 -4
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include "ggml-common.h"
|
|
4
|
+
|
|
5
|
+
// Convert a single scalar between floating-point representations.
// When source and destination types match this is a plain copy;
// otherwise the value is routed through float, the common interchange
// type for these kernels.
template<typename src_t, typename dst_t>
static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) {
    if constexpr (!std::is_same_v<src_t, dst_t>) {
        *dst = float(*src);
    } else {
        *dst = *src;
    }
}
|
|
13
|
+
|
|
14
|
+
// Binary-search a sorted int8 table of n entries for the value nearest
// to x. Out-of-range inputs clamp to the first/last index; on an exact
// distance tie the upper index wins.
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
    if (x <= val[0]) {
        return 0;
    }
    if (x >= val[n-1]) {
        return n-1;
    }

    int lo = 0;
    int hi = n - 1;
    while (hi - lo > 1) {
        const int mid = (lo + hi)/2;
        if (x < val[mid]) {
            hi = mid;
        } else {
            lo = mid;
        }
    }

    // x now lies between val[hi-1] and val[hi]; pick the closer one.
    return (x - val[hi-1] < val[hi] - x) ? hi-1 : hi;
}
|
|
24
|
+
|
|
25
|
+
// Quantize one block of QK4_0 floats to Q4_0: a single scale plus 4-bit
// levels packed two nibbles per byte.
static __device__ void quantize_f32_q4_0_block(const float * __restrict__ x, block_q4_0 * __restrict__ y) {
    float amax = 0.0f; // largest |x[j]| in the block
    float vmax = 0.0f; // signed value that attains amax

    for (int j = 0; j < QK4_0; ++j) {
        const float v = x[j];
        if (amax < fabsf(v)) {
            amax = fabsf(v);
            vmax = v;
        }
    }

    // Scale so the extreme value maps to -8 (the most negative 4-bit level).
    const float d = vmax / -8;
    const float id = d ? 1.0f/d : 0.0f; // inverse scale; 0 when the block is all zeros

    y->d = d;

    // Pack pairs (j, j + QK4_0/2) into one byte: low nibble from the first
    // half of the block, high nibble from the second half.
    for (int j = 0; j < QK4_0/2; ++j) {
        const float x0 = x[0 + j]*id;
        const float x1 = x[QK4_0/2 + j]*id;

        // +8.5f shifts into the unsigned nibble range and rounds; clamp at 15.
        const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
        const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));

        y->qs[j] = xi0;
        y->qs[j] |= xi1 << 4;
    }
}
|
|
53
|
+
|
|
54
|
+
// Quantize one block of QK4_1 floats to Q4_1: scale and minimum stored
// together in dm, plus unsigned 4-bit levels packed two nibbles per byte.
static __device__ void quantize_f32_q4_1_block(const float * __restrict__ x, block_q4_1 * __restrict__ y) {
    float vmin = FLT_MAX;
    float vmax = -FLT_MAX;

    for (int j = 0; j < QK4_1; ++j) {
        const float v = x[j];
        if (v < vmin) vmin = v;
        if (v > vmax) vmax = v;
    }

    // Spread [vmin, vmax] across the 16 available levels.
    const float d = (vmax - vmin) / ((1 << 4) - 1);
    const float id = d ? 1.0f/d : 0.0f; // 0 when the block is constant

    y->dm.x = d;
    y->dm.y = vmin;

    for (int j = 0; j < QK4_1/2; ++j) {
        const float x0 = (x[0 + j] - vmin)*id;
        const float x1 = (x[QK4_1/2 + j] - vmin)*id;

        // round-to-nearest; clamp at the top level
        const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
        const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));

        y->qs[j] = xi0;
        y->qs[j] |= xi1 << 4;
    }
}
|
|
81
|
+
|
|
82
|
+
// Quantize one block of QK5_0 floats to Q5_0: one scale, the 4 low bits
// of each value packed in qs, and the 5th (high) bit of every value
// gathered into the qh bitfield.
static __device__ void quantize_f32_q5_0_block(const float * __restrict__ x, block_q5_0 * __restrict__ y) {
    float amax = 0.0f; // largest |x[j]| in the block
    float vmax = 0.0f; // signed value that attains amax

    for (int j = 0; j < QK5_0; ++j) {
        const float v = x[j];
        if (amax < fabsf(v)) {
            amax = fabsf(v);
            vmax = v;
        }
    }

    // Scale so the extreme value maps to -16 (most negative 5-bit level).
    const float d = vmax / -16;
    const float id = d ? 1.0f/d : 0.0f; // 0 when the block is all zeros

    y->d = d;

    uint32_t qh = 0; // one high bit per element: first half, then second half
    for (int j = 0; j < QK5_0/2; ++j) {
        const float x0 = x[0 + j]*id;
        const float x1 = x[QK5_0/2 + j]*id;

        // +16.5f shifts into the unsigned 5-bit range and rounds; clamp at 31.
        const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
        const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));

        y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
    }
    memcpy(y->qh, &qh, sizeof(qh));
}
|
|
113
|
+
|
|
114
|
+
// Quantize one block of QK5_1 floats to Q5_1: scale + minimum in dm, low
// nibbles of each value in qs, and the 5th bit of each value in qh.
static __device__ void quantize_f32_q5_1_block(const float * __restrict__ x, block_q5_1 * __restrict__ y) {
    float min = x[0];
    float max = x[0];

    for (int j = 1; j < QK5_1; ++j) {
        const float v = x[j];
        min = v < min ? v : min;
        max = v > max ? v : max;
    }

    // Spread [min, max] across the 32 available levels.
    const float d = (max - min) / 31;
    const float id = d ? 1.0f/d : 0.0f; // 0 when the block is constant

    y->dm.x = d;
    y->dm.y = min;

    uint32_t qh = 0; // one high bit per element: first half, then second half
    for (int j = 0; j < QK5_1/2; ++j) {
        const float x0 = (x[0 + j] - min)*id;
        const float x1 = (x[QK5_1/2 + j] - min)*id;

        // round-to-nearest; values are already within [0, 31] by construction
        const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
        const uint8_t xi1 = (uint8_t)(x1 + 0.5f);

        y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
    }
    memcpy(y->qh, &qh, sizeof(qh));
}
|
|
144
|
+
|
|
145
|
+
// Quantize one block of QK8_0 floats to Q8_0: a single scale plus one
// signed 8-bit integer per value.
static __device__ void quantize_f32_q8_0_block(const float * __restrict__ x, block_q8_0 * __restrict__ y) {
    float amax = 0.0f; // absolute max

    for (int j = 0; j < QK8_0; j++) {
        const float v = x[j];
        amax = fmaxf(amax, fabsf(v));
    }

    // Map the absolute extreme onto 127, the largest int8 magnitude used.
    const float d = amax / ((1 << 7) - 1);
    const float id = d ? 1.0f/d : 0.0f; // 0 when the block is all zeros

    y->d = d;

    for (int j = 0; j < QK8_0; ++j) {
        const float x0 = x[j]*id;
        y->qs[j] = roundf(x0);
    }
}
|
|
163
|
+
|
|
164
|
+
// Quantize one block of QK4_NL floats to IQ4_NL: 4-bit indices into the
// non-linear kvalues_iq4nl table, plus a scale refined by a weighted
// least-squares fit over the chosen table entries.
static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, block_iq4_nl * __restrict__ y) {
    float amax = 0.0f; // largest |x[j]| in the block
    float vmax = 0.0f; // signed value that attains amax

    for (int j = 0; j < QK4_NL; ++j) {
        const float v = x[j];
        if (amax < fabsf(v)) {
            amax = fabsf(v);
            vmax = v;
        }
    }

    // Initial scale: map the extreme value onto the first (most negative)
    // table entry.
    float d = vmax / kvalues_iq4nl[0];
    const float id = d ? 1.0f/d : 0.0f;

    // Quantize, and simultaneously accumulate the sums needed to re-fit d:
    // sumqx = sum(w*q*x), sumq2 = sum(w*q*q), with weights w = x^2.
    float sumqx = 0, sumq2 = 0;
    for (int j = 0; j < QK4_NL/2; ++j) {
        const float x0 = x[0 + j]*id;
        const float x1 = x[QK4_NL/2 + j]*id;
        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
        y->qs[j] = xi0 | (xi1 << 4);
        const float v0 = kvalues_iq4nl[xi0];
        const float v1 = kvalues_iq4nl[xi1];
        const float w0 = x[0 + j]*x[0 + j];
        const float w1 = x[QK4_NL/2 + j]*x[QK4_NL/2 + j];
        sumqx += w0*v0*x[j] + w1*v1*x[QK4_NL/2 + j];
        sumq2 += w0*v0*v0 + w1*v1*v1;
    }

    // Least-squares optimal scale; fall back to the initial guess for an
    // all-zero block (sumq2 == 0).
    y->d = sumq2 > 0 ? sumqx/sumq2 : d;
}
|
|
196
|
+
|
|
197
|
+
// Wrapper functions for cpy.cu compatibility
|
|
198
|
+
// Type-erased wrapper for the cpy.cu kernel tables: quantize one f32
// block at cxi into a Q4_0 block at cdsti.
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
    const float * src = (const float *) cxi;
    block_q4_0 * dst = (block_q4_0 *) cdsti;
    quantize_f32_q4_0_block(src, dst);
}
|
|
201
|
+
|
|
202
|
+
// Type-erased wrapper for the cpy.cu kernel tables: quantize one f32
// block at cxi into a Q4_1 block at cdsti.
static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
    const float * src = (const float *) cxi;
    block_q4_1 * dst = (block_q4_1 *) cdsti;
    quantize_f32_q4_1_block(src, dst);
}
|
|
205
|
+
|
|
206
|
+
// Type-erased wrapper for the cpy.cu kernel tables: quantize one f32
// block at cxi into a Q5_0 block at cdsti.
static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
    const float * src = (const float *) cxi;
    block_q5_0 * dst = (block_q5_0 *) cdsti;
    quantize_f32_q5_0_block(src, dst);
}
|
|
209
|
+
|
|
210
|
+
// Type-erased wrapper for the cpy.cu kernel tables: quantize one f32
// block at cxi into a Q5_1 block at cdsti.
static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
    const float * src = (const float *) cxi;
    block_q5_1 * dst = (block_q5_1 *) cdsti;
    quantize_f32_q5_1_block(src, dst);
}
|
|
213
|
+
|
|
214
|
+
// Type-erased wrapper for the cpy.cu kernel tables: quantize one f32
// block at cxi into a Q8_0 block at cdsti.
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
    const float * src = (const float *) cxi;
    block_q8_0 * dst = (block_q8_0 *) cdsti;
    quantize_f32_q8_0_block(src, dst);
}
|
|
217
|
+
|
|
218
|
+
// Type-erased wrapper for the cpy.cu kernel tables: quantize one f32
// block at cxi into an IQ4_NL block at cdsti.
static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
    const float * src = (const float *) cxi;
    block_iq4_nl * dst = (block_iq4_nl *) cdsti;
    quantize_f32_iq4_nl_block(src, dst);
}
|
|
221
|
+
|
|
222
|
+
// Single-element copy with floating-point conversion, type-erased so it
// can be used as a cpy_kernel_t in the cpy.cu dispatch tables.
template<typename src_t, typename dst_t>
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
    const src_t * src = (const src_t *) cxi;
    dst_t * dst = (dst_t *) cdsti;
    convert_flt(src, dst);
}
|
|
@@ -1,51 +1,17 @@
|
|
|
1
1
|
#include "cpy.cuh"
|
|
2
2
|
#include "dequantize.cuh"
|
|
3
|
-
#
|
|
3
|
+
#include "cpy-utils.cuh"
|
|
4
|
+
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
|
|
4
5
|
#include "ggml-musa/mudnn.cuh"
|
|
5
|
-
#endif // GGML_USE_MUSA
|
|
6
|
+
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
|
|
6
7
|
|
|
7
8
|
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
|
|
8
9
|
|
|
9
|
-
static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
|
|
10
|
-
const float * xi = (const float *) cxi;
|
|
11
|
-
float * dsti = (float *) cdsti;
|
|
12
|
-
|
|
13
|
-
*dsti = *xi;
|
|
14
|
-
}
|
|
15
|
-
|
|
16
|
-
static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
|
|
17
|
-
const float * xi = (const float *) cxi;
|
|
18
|
-
nv_bfloat16 * dsti = (nv_bfloat16 *) cdsti;
|
|
19
|
-
|
|
20
|
-
*dsti = *xi;
|
|
21
|
-
}
|
|
22
|
-
|
|
23
|
-
static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
|
|
24
|
-
const float * xi = (const float *) cxi;
|
|
25
|
-
half * dsti = (half *) cdsti;
|
|
26
|
-
|
|
27
|
-
*dsti = __float2half(*xi);
|
|
28
|
-
}
|
|
29
|
-
|
|
30
|
-
static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
|
|
31
|
-
const half * xi = (const half *) cxi;
|
|
32
|
-
half * dsti = (half *) cdsti;
|
|
33
|
-
|
|
34
|
-
*dsti = *xi;
|
|
35
|
-
}
|
|
36
|
-
|
|
37
|
-
static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
|
|
38
|
-
const half * xi = (const half *) cxi;
|
|
39
|
-
float * dsti = (float *) cdsti;
|
|
40
|
-
|
|
41
|
-
*dsti = *xi;
|
|
42
|
-
}
|
|
43
|
-
|
|
44
10
|
template <cpy_kernel_t cpy_1>
|
|
45
|
-
static __global__ void
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
11
|
+
static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
|
|
12
|
+
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
|
13
|
+
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
|
14
|
+
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
|
|
49
15
|
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
50
16
|
|
|
51
17
|
if (i >= ne) {
|
|
@@ -71,29 +37,6 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const in
|
|
|
71
37
|
cpy_1(cx + x_offset, cdst + dst_offset);
|
|
72
38
|
}
|
|
73
39
|
|
|
74
|
-
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
|
|
75
|
-
const float * xi = (const float *) cxi;
|
|
76
|
-
block_q8_0 * dsti = (block_q8_0 *) cdsti;
|
|
77
|
-
|
|
78
|
-
float amax = 0.0f; // absolute max
|
|
79
|
-
|
|
80
|
-
for (int j = 0; j < QK8_0; j++) {
|
|
81
|
-
const float v = xi[j];
|
|
82
|
-
amax = fmaxf(amax, fabsf(v));
|
|
83
|
-
}
|
|
84
|
-
|
|
85
|
-
const float d = amax / ((1 << 7) - 1);
|
|
86
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
87
|
-
|
|
88
|
-
dsti->d = d;
|
|
89
|
-
|
|
90
|
-
for (int j = 0; j < QK8_0; ++j) {
|
|
91
|
-
const float x0 = xi[j]*id;
|
|
92
|
-
|
|
93
|
-
dsti->qs[j] = roundf(x0);
|
|
94
|
-
}
|
|
95
|
-
}
|
|
96
|
-
|
|
97
40
|
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
|
|
98
41
|
float * cdstf = (float *)(cdsti);
|
|
99
42
|
|
|
@@ -106,139 +49,6 @@ static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
|
|
|
106
49
|
}
|
|
107
50
|
}
|
|
108
51
|
|
|
109
|
-
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
|
|
110
|
-
const float * xi = (const float *) cxi;
|
|
111
|
-
block_q4_0 * dsti = (block_q4_0 *) cdsti;
|
|
112
|
-
|
|
113
|
-
float amax = 0.0f;
|
|
114
|
-
float vmax = 0.0f;
|
|
115
|
-
|
|
116
|
-
for (int j = 0; j < QK4_0; ++j) {
|
|
117
|
-
const float v = xi[j];
|
|
118
|
-
if (amax < fabsf(v)) {
|
|
119
|
-
amax = fabsf(v);
|
|
120
|
-
vmax = v;
|
|
121
|
-
}
|
|
122
|
-
}
|
|
123
|
-
|
|
124
|
-
const float d = vmax / -8;
|
|
125
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
126
|
-
|
|
127
|
-
dsti->d = d;
|
|
128
|
-
|
|
129
|
-
for (int j = 0; j < QK4_0/2; ++j) {
|
|
130
|
-
const float x0 = xi[0 + j]*id;
|
|
131
|
-
const float x1 = xi[QK4_0/2 + j]*id;
|
|
132
|
-
|
|
133
|
-
const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
|
|
134
|
-
const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
|
|
135
|
-
|
|
136
|
-
dsti->qs[j] = xi0;
|
|
137
|
-
dsti->qs[j] |= xi1 << 4;
|
|
138
|
-
}
|
|
139
|
-
}
|
|
140
|
-
|
|
141
|
-
static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
|
|
142
|
-
const float * xi = (const float *) cxi;
|
|
143
|
-
block_q4_1 * dsti = (block_q4_1 *) cdsti;
|
|
144
|
-
|
|
145
|
-
float vmin = FLT_MAX;
|
|
146
|
-
float vmax = -FLT_MAX;
|
|
147
|
-
|
|
148
|
-
for (int j = 0; j < QK4_1; ++j) {
|
|
149
|
-
const float v = xi[j];
|
|
150
|
-
|
|
151
|
-
if (v < vmin) vmin = v;
|
|
152
|
-
if (v > vmax) vmax = v;
|
|
153
|
-
}
|
|
154
|
-
|
|
155
|
-
const float d = (vmax - vmin) / ((1 << 4) - 1);
|
|
156
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
157
|
-
|
|
158
|
-
dsti->dm.x = d;
|
|
159
|
-
dsti->dm.y = vmin;
|
|
160
|
-
|
|
161
|
-
for (int j = 0; j < QK4_1/2; ++j) {
|
|
162
|
-
const float x0 = (xi[0 + j] - vmin)*id;
|
|
163
|
-
const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
|
|
164
|
-
|
|
165
|
-
const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
|
|
166
|
-
const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
|
|
167
|
-
|
|
168
|
-
dsti->qs[j] = xi0;
|
|
169
|
-
dsti->qs[j] |= xi1 << 4;
|
|
170
|
-
}
|
|
171
|
-
}
|
|
172
|
-
|
|
173
|
-
static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
|
|
174
|
-
const float * xi = (const float *) cxi;
|
|
175
|
-
block_q5_0 * dsti = (block_q5_0 *) cdsti;
|
|
176
|
-
|
|
177
|
-
float amax = 0.0f;
|
|
178
|
-
float vmax = 0.0f;
|
|
179
|
-
|
|
180
|
-
for (int j = 0; j < QK5_0; ++j) {
|
|
181
|
-
const float v = xi[j];
|
|
182
|
-
if (amax < fabsf(v)) {
|
|
183
|
-
amax = fabsf(v);
|
|
184
|
-
vmax = v;
|
|
185
|
-
}
|
|
186
|
-
}
|
|
187
|
-
|
|
188
|
-
const float d = vmax / -16;
|
|
189
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
190
|
-
|
|
191
|
-
dsti->d = d;
|
|
192
|
-
|
|
193
|
-
uint32_t qh = 0;
|
|
194
|
-
for (int j = 0; j < QK5_0/2; ++j) {
|
|
195
|
-
const float x0 = xi[0 + j]*id;
|
|
196
|
-
const float x1 = xi[QK5_0/2 + j]*id;
|
|
197
|
-
|
|
198
|
-
const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
|
|
199
|
-
const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
|
|
200
|
-
|
|
201
|
-
dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
|
202
|
-
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
|
203
|
-
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
|
|
204
|
-
}
|
|
205
|
-
memcpy(dsti->qh, &qh, sizeof(qh));
|
|
206
|
-
}
|
|
207
|
-
|
|
208
|
-
static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
|
|
209
|
-
const float * xi = (const float *) cxi;
|
|
210
|
-
block_q5_1 * dsti = (block_q5_1 *) cdsti;
|
|
211
|
-
|
|
212
|
-
float min = xi[0];
|
|
213
|
-
float max = xi[0];
|
|
214
|
-
|
|
215
|
-
for (int j = 1; j < QK5_1; ++j) {
|
|
216
|
-
const float v = xi[j];
|
|
217
|
-
min = v < min ? v : min;
|
|
218
|
-
max = v > max ? v : max;
|
|
219
|
-
}
|
|
220
|
-
|
|
221
|
-
const float d = (max - min) / 31;
|
|
222
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
223
|
-
|
|
224
|
-
dsti->dm.x = d;
|
|
225
|
-
dsti->dm.y = min;
|
|
226
|
-
|
|
227
|
-
uint32_t qh = 0;
|
|
228
|
-
for (int j = 0; j < QK5_1/2; ++j) {
|
|
229
|
-
const float x0 = (xi[0 + j] - min)*id;
|
|
230
|
-
const float x1 = (xi[QK5_1/2 + j] - min)*id;
|
|
231
|
-
|
|
232
|
-
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
|
|
233
|
-
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
|
|
234
|
-
|
|
235
|
-
dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
|
236
|
-
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
|
237
|
-
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
|
|
238
|
-
}
|
|
239
|
-
memcpy(dsti->qh, &qh, sizeof(qh));
|
|
240
|
-
}
|
|
241
|
-
|
|
242
52
|
template<dequantize_kernel_t dequant, int qk>
|
|
243
53
|
static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
|
|
244
54
|
float * cdstf = (float *)(cdsti);
|
|
@@ -252,53 +62,6 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
|
|
|
252
62
|
}
|
|
253
63
|
}
|
|
254
64
|
|
|
255
|
-
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
|
|
256
|
-
if (x <= val[0]) return 0;
|
|
257
|
-
if (x >= val[n-1]) return n-1;
|
|
258
|
-
int ml = 0, mu = n-1;
|
|
259
|
-
while (mu-ml > 1) {
|
|
260
|
-
int mav = (ml+mu)/2;
|
|
261
|
-
if (x < val[mav]) mu = mav; else ml = mav;
|
|
262
|
-
}
|
|
263
|
-
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
|
|
264
|
-
}
|
|
265
|
-
|
|
266
|
-
static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
|
|
267
|
-
const float * xi = (const float *) cxi;
|
|
268
|
-
block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
|
|
269
|
-
|
|
270
|
-
float amax = 0.0f;
|
|
271
|
-
float vmax = 0.0f;
|
|
272
|
-
|
|
273
|
-
for (int j = 0; j < QK4_NL; ++j) {
|
|
274
|
-
const float v = xi[j];
|
|
275
|
-
if (amax < fabsf(v)) {
|
|
276
|
-
amax = fabsf(v);
|
|
277
|
-
vmax = v;
|
|
278
|
-
}
|
|
279
|
-
}
|
|
280
|
-
|
|
281
|
-
float d = vmax / kvalues_iq4nl[0];
|
|
282
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
283
|
-
|
|
284
|
-
float sumqx = 0, sumq2 = 0;
|
|
285
|
-
for (int j = 0; j < QK4_NL/2; ++j) {
|
|
286
|
-
const float x0 = xi[0 + j]*id;
|
|
287
|
-
const float x1 = xi[QK4_NL/2 + j]*id;
|
|
288
|
-
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
|
|
289
|
-
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
|
|
290
|
-
dsti->qs[j] = xi0 | (xi1 << 4);
|
|
291
|
-
const float v0 = kvalues_iq4nl[xi0];
|
|
292
|
-
const float v1 = kvalues_iq4nl[xi1];
|
|
293
|
-
const float w0 = xi[0 + j]*xi[0 + j];
|
|
294
|
-
const float w1 = xi[QK4_NL/2 + j]*xi[QK4_NL/2 + j];
|
|
295
|
-
sumqx += w0*v0*xi[j] + w1*v1*xi[QK4_NL/2 + j];
|
|
296
|
-
sumq2 += w0*v0*v0 + w1*v1*v1;
|
|
297
|
-
}
|
|
298
|
-
|
|
299
|
-
dsti->d = sumq2 > 0 ? sumqx/sumq2 : d;
|
|
300
|
-
}
|
|
301
|
-
|
|
302
65
|
template <cpy_kernel_t cpy_blck, int qk>
|
|
303
66
|
static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
|
|
304
67
|
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
|
@@ -358,7 +121,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int
|
|
|
358
121
|
// Copy destination pointers to GPU to be available when pointer indirection is in use
|
|
359
122
|
|
|
360
123
|
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
|
|
361
|
-
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
|
|
124
|
+
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
|
|
362
125
|
if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
|
|
363
126
|
CUDA_CHECK(cudaStreamSynchronize(stream));
|
|
364
127
|
if (cuda_graph->dest_ptrs_d != nullptr) {
|
|
@@ -376,43 +139,14 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des
|
|
|
376
139
|
#endif
|
|
377
140
|
}
|
|
378
141
|
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
|
382
|
-
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
|
383
|
-
|
|
384
|
-
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
|
385
|
-
cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
|
386
|
-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
|
387
|
-
}
|
|
388
|
-
|
|
389
|
-
static void ggml_cpy_f32_f32_cuda(
|
|
142
|
+
template<typename src_t, typename dst_t>
|
|
143
|
+
static void ggml_cpy_flt_cuda(
|
|
390
144
|
const char * cx, char * cdst, const int ne,
|
|
391
145
|
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
|
392
146
|
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
|
393
147
|
|
|
394
148
|
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
|
395
|
-
|
|
396
|
-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
|
397
|
-
}
|
|
398
|
-
|
|
399
|
-
static void ggml_cpy_f32_bf16_cuda(
|
|
400
|
-
const char * cx, char * cdst, const int ne,
|
|
401
|
-
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
|
402
|
-
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
|
403
|
-
|
|
404
|
-
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
|
405
|
-
cpy_f32_f16<cpy_1_f32_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
|
406
|
-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
|
407
|
-
}
|
|
408
|
-
|
|
409
|
-
static void ggml_cpy_f32_f16_cuda(
|
|
410
|
-
const char * cx, char * cdst, const int ne,
|
|
411
|
-
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
|
412
|
-
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
|
413
|
-
|
|
414
|
-
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
|
415
|
-
cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
|
149
|
+
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
|
416
150
|
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
|
417
151
|
}
|
|
418
152
|
|
|
@@ -544,16 +278,6 @@ static void ggml_cpy_f32_iq4_nl_cuda(
|
|
|
544
278
|
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
|
545
279
|
}
|
|
546
280
|
|
|
547
|
-
static void ggml_cpy_f16_f16_cuda(
|
|
548
|
-
const char * cx, char * cdst, const int ne,
|
|
549
|
-
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
|
550
|
-
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
|
551
|
-
|
|
552
|
-
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
|
553
|
-
cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
|
554
|
-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
|
555
|
-
}
|
|
556
|
-
|
|
557
281
|
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
|
|
558
282
|
const int64_t ne = ggml_nelements(src0);
|
|
559
283
|
GGML_ASSERT(ne == ggml_nelements(src1));
|
|
@@ -590,7 +314,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
|
|
590
314
|
|
|
591
315
|
char ** dest_ptrs_d = nullptr;
|
|
592
316
|
int graph_cpynode_index = -1;
|
|
593
|
-
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
|
|
317
|
+
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
|
|
594
318
|
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
|
|
595
319
|
dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
|
|
596
320
|
graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
|
|
@@ -600,20 +324,20 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
|
|
600
324
|
#endif
|
|
601
325
|
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
|
602
326
|
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
|
|
603
|
-
#
|
|
327
|
+
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
|
|
604
328
|
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
|
|
605
329
|
CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0));
|
|
606
330
|
} else
|
|
607
|
-
#endif // GGML_USE_MUSA
|
|
331
|
+
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
|
|
608
332
|
{
|
|
609
333
|
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
|
|
610
334
|
}
|
|
611
335
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
|
612
|
-
|
|
336
|
+
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
|
613
337
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
|
614
|
-
|
|
338
|
+
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
|
615
339
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
|
616
|
-
|
|
340
|
+
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
|
617
341
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
|
618
342
|
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
|
619
343
|
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
|
|
@@ -640,14 +364,22 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
|
|
640
364
|
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
|
641
365
|
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
|
642
366
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
|
643
|
-
|
|
367
|
+
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
|
368
|
+
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
|
|
369
|
+
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
|
644
370
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
|
645
|
-
|
|
371
|
+
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
|
372
|
+
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
|
|
373
|
+
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
|
374
|
+
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
|
|
375
|
+
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
|
376
|
+
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
|
377
|
+
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
|
646
378
|
} else {
|
|
647
379
|
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
|
648
380
|
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
|
649
381
|
}
|
|
650
|
-
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
|
|
382
|
+
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
|
|
651
383
|
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
|
|
652
384
|
ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
|
|
653
385
|
}
|
|
@@ -667,11 +399,11 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
|
|
667
399
|
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
|
668
400
|
return nullptr;
|
|
669
401
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
|
670
|
-
return (void*)
|
|
402
|
+
return (void*) cpy_flt<cpy_1_flt<float, float>>;
|
|
671
403
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
|
672
|
-
return (void*)
|
|
404
|
+
return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
|
|
673
405
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
|
674
|
-
return (void*)
|
|
406
|
+
return (void*) cpy_flt<cpy_1_flt<float, half>>;
|
|
675
407
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
|
676
408
|
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
|
|
677
409
|
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
|
|
@@ -695,9 +427,17 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
|
|
695
427
|
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
|
696
428
|
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
|
|
697
429
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
|
698
|
-
return (void*)
|
|
430
|
+
return (void*) cpy_flt<cpy_1_flt<half, half>>;
|
|
431
|
+
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
|
|
432
|
+
return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
|
|
699
433
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
|
700
|
-
return (void*)
|
|
434
|
+
return (void*) cpy_flt<cpy_1_flt<half, float>>;
|
|
435
|
+
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
|
|
436
|
+
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
|
|
437
|
+
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
|
|
438
|
+
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
|
|
439
|
+
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
|
440
|
+
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
|
|
701
441
|
} else {
|
|
702
442
|
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
|
703
443
|
ggml_type_name(src0->type), ggml_type_name(src1->type));
|