@novastera-oss/llamarn 0.2.7 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +8 -3
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +56 -22
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +1 -2
- package/cpp/llama.cpp/README.md +4 -5
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +24 -0
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +3 -0
- package/cpp/llama.cpp/common/common.h +5 -0
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
- package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
- package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
- package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
- package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
- package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
- package/cpp/llama.cpp/include/llama.h +8 -43
- package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
- package/cpp/llama.cpp/src/llama-arch.h +36 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
- package/cpp/llama.cpp/src/llama-batch.h +105 -70
- package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +101 -107
- package/cpp/llama.cpp/src/llama-context.h +13 -13
- package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
- package/cpp/llama.cpp/src/llama-graph.h +78 -79
- package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
- package/cpp/llama.cpp/src/llama-hparams.h +11 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
- package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +21 -22
- package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
- package/cpp/llama.cpp/src/llama-model.h +40 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
- package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
- package/cpp/llama.cpp/src/llama-vocab.h +42 -0
- package/cpp/rn-utils.h +3 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +5 -0
- package/ios/include/llama.h +8 -43
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
#include "conv2d-dw.cuh"
|
|
2
|
+
|
|
3
|
+
// Geometry of one depthwise 2D convolution problem.
// Shared by both tensor layouts (WHCN and CWHN); all sizes are element counts.
struct conv_params {
    int in_w, in_h;          // input spatial extent (width, height)
    int out_w, out_h;        // output spatial extent (width, height)
    int kernel_w, kernel_h;  // filter spatial extent
    int stride_x, stride_y;  // step of the filter window over the input
    int padding_x, padding_y; // implicit zero-padding on each border
    int dilation_x, dilation_y; // spacing between filter taps
    int channels, batches;   // depthwise: one filter per channel
};
|
|
12
|
+
|
|
13
|
+
// Half-open ranges [y_min, y_max) x [x_min, x_max) of filter taps whose
// corresponding input coordinate falls inside the input tensor, so the
// inner loops never need per-tap bounds checks.
struct kernel_bounds {
    int y_min, y_max;
    int x_min, x_max;
};
|
|
17
|
+
|
|
18
|
+
// Compute the range of valid filter taps for output position (out_x, out_y).
// A tap k is valid when 0 <= out*stride + k*dilation - padding < in_size;
// solving for k gives the ceil-divisions below (the "+ dilation - 1" turns
// integer division into a ceiling). min/max clamp to the filter extent.
__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int out_x, int out_y, const conv_params & params) {
    kernel_bounds bounds;
    bounds.y_min = max(0, (params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
    bounds.y_max =
        min(params.kernel_h,
            (params.in_h + params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
    bounds.x_min = max(0, (params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
    bounds.x_max =
        min(params.kernel_w,
            (params.in_w + params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
    return bounds;
}
|
|
30
|
+
|
|
31
|
+
// Map an output coordinate plus a filter tap to the input coordinate it reads:
// the window origin (out*stride - padding) shifted by the dilated tap offset.
__device__ __forceinline__ int calculate_input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) {
    const int window_origin = out_coord * stride - padding;
    return window_origin + kern_coord * dilation;
}
|
|
34
|
+
|
|
35
|
+
// Index helpers for the WHCN memory layout (width fastest, then height,
// channel, batch) — ggml's default contiguous ordering for 4D tensors.
struct whcn_layout {
    // Flat offset of input element (n, c, y, x).
    __device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
        const int plane = params.in_w * params.in_h;
        return (n * params.channels + c) * plane + y * params.in_w + x;
    }

    // Flat offset of filter tap (ky, kx) for channel c (one filter per channel).
    __device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
        return (c * params.kernel_h + ky) * params.kernel_w + kx;
    }

    // Flat offset of output element (n, c, y, x).
    __device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
        const int plane = params.out_w * params.out_h;
        return (n * params.channels + c) * plane + y * params.out_w + x;
    }

    // Invert the flat output index into (n, c, out_y, out_x) by successive
    // division, peeling off the fastest-varying dimension first.
    __device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
                                          int & out_x) {
        int rem = global_idx;
        out_x = rem % params.out_w;
        rem /= params.out_w;
        out_y = rem % params.out_h;
        rem /= params.out_h;
        c = rem % params.channels;
        n = rem / params.channels;
    }
};
|
|
57
|
+
|
|
58
|
+
// Index helpers for the CWHN memory layout (channel fastest, then width,
// height, batch) — the "channels contiguous" ordering.
struct cwhn_layout {
    // Flat offset of input element (n, c, y, x): full images are batch-major,
    // pixels are row-major, channels are innermost.
    __device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
        const int image = params.channels * params.in_w * params.in_h;
        return n * image + (y * params.in_w + x) * params.channels + c;
    }

    // Flat offset of filter tap (ky, kx) for channel c; channels innermost.
    __device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
        const int tap = ky * params.kernel_w + kx;
        return tap * params.channels + c;
    }

    // Flat offset of output element (n, c, y, x).
    __device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
        const int image = params.channels * params.out_w * params.out_h;
        const int row   = params.out_w * params.channels;
        return n * image + y * row + x * params.channels + c;
    }

    // Invert the flat output index into (n, c, out_y, out_x) by successive
    // division; here the channel is the fastest-varying dimension.
    __device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
                                          int & out_x) {
        int rem = global_idx;
        c = rem % params.channels;
        rem /= params.channels;
        out_x = rem % params.out_w;
        rem /= params.out_w;
        out_y = rem % params.out_h;
        n = rem / params.out_h;
    }
};
|
|
80
|
+
|
|
81
|
+
// Depthwise 2D convolution: one thread computes one output element.
// Layout (whcn_layout or cwhn_layout) abstracts the memory ordering so the
// same arithmetic serves both contiguous and channels-contiguous tensors.
// Threads beyond the flat output size exit immediately.
template <typename T, typename Layout>
__global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restrict__ kernel, T * __restrict__ output,
                                 const int in_w, const int in_h, const int out_w, const int out_h,
                                 const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
                                 const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
                                 const int channels, const int batches) {
    const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int total_elements = batches * channels * out_h * out_w;

    if (global_idx >= total_elements) {
        return;
    }

    conv_params params = { in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x,
                           stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };

    // Recover (batch, channel, out_y, out_x) from the flat thread index.
    int batch_idx, channel_idx, out_y_idx, out_x_idx;
    Layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);

    T accumulator = 0;
    // Pre-clip the tap ranges so the loops below never read out of bounds
    // (implicit zero-padding contributes nothing to the sum).
    kernel_bounds bounds = calculate_kernel_bounds(out_x_idx, out_y_idx, params);

    for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) {
        int in_y_idx = calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y);

        for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) {
            int in_x_idx = calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x);

            const T input_val = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
            const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)];

            accumulator += input_val * kernel_val;
        }
    }

    output[Layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = accumulator;
}
|
|
118
|
+
|
|
119
|
+
// Host-side launcher for GGML_OP_CONV_2D_DW (depthwise 2D convolution).
// src[0] is the filter, src[1] the input; only F32 in/out is supported.
// Convolution parameters (strides, padding, dilation) arrive via op_params.
// Dispatches the layout-templated kernel based on the input tensor's memory
// ordering; any other layout aborts.
void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * kernel = dst->src[0];
    const ggml_tensor * input  = dst->src[1];

    GGML_ASSERT(kernel->type == GGML_TYPE_F32 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
    const float * w_d = (const float *) kernel->data;
    const float * x_d = (const float *) input->data;
    float * y_d = (float *) dst->data;

    // op_params layout: [stride_x, stride_y, padding_x, padding_y, dilation_x, dilation_y]
    const int32_t * p = (const int32_t *) dst->op_params;
    const int stride_x = p[0];
    const int stride_y = p[1];
    const int padding_x = p[2];
    const int padding_y = p[3];
    const int dilation_x = p[4];
    const int dilation_y = p[5];

    // ne[] ordering follows ggml: ne[0] = width, ne[1] = height,
    // ne[2] = channels, ne[3] = batch.
    const int in_w = input->ne[0];
    const int in_h = input->ne[1];
    const int kernel_w = kernel->ne[0];
    const int kernel_h = kernel->ne[1];
    const int out_w = dst->ne[0];
    const int out_h = dst->ne[1];
    const int channels = dst->ne[2];
    const int batches = dst->ne[3];

    cudaStream_t st = ctx.stream();

    // One thread per output element.
    const int total  = batches * channels * out_h * out_w;
    const int blocks = (total + CUDA_CONV2D_DW_BLOCK_SIZE - 1) / CUDA_CONV2D_DW_BLOCK_SIZE;

    if (ggml_is_contiguous(input)) {
        // Standard ggml contiguous layout: width fastest (WHCN).
        conv2d_dw_kernel<float, whcn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
            x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
            dilation_x, dilation_y, channels, batches);
    } else if (ggml_is_contiguous_channels(input)) {
        // Channels-innermost layout (CWHN).
        conv2d_dw_kernel<float, cwhn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
            x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
            dilation_x, dilation_y, channels, batches);
    } else {
        GGML_ABORT("Unsupported memory layout for conv_2d_dw");
    }
}
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
#include <algorithm>
|
|
2
|
+
|
|
3
|
+
#include "conv2d-transpose.cuh"
|
|
4
|
+
#include "ggml.h"
|
|
5
|
+
|
|
6
|
+
__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,
|
|
7
|
+
float * __restrict__ output, const int in_w, const int in_h, const int out_w,
|
|
8
|
+
const int out_h, const int kernel_w, const int kernel_h, const int stride,
|
|
9
|
+
const int c_in, const int c_out, const int batches) {
|
|
10
|
+
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
|
11
|
+
|
|
12
|
+
const int total_elements = out_w * out_h * c_out * batches;
|
|
13
|
+
|
|
14
|
+
if (global_idx >= total_elements) {
|
|
15
|
+
return;
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
const int out_x_idx = global_idx % out_w;
|
|
19
|
+
const int out_y_idx = (global_idx / out_w) % out_h;
|
|
20
|
+
const int c_idx = (global_idx / (out_w * out_h)) % c_out;
|
|
21
|
+
const int n_idx = global_idx / (out_w * out_h * c_out);
|
|
22
|
+
|
|
23
|
+
float accumulator = 0;
|
|
24
|
+
// For each output idx, find the inputs that contribute to it by checking stride alignment and bounds
|
|
25
|
+
|
|
26
|
+
for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) {
|
|
27
|
+
for (int kh = 0; kh < kernel_h; ++kh) {
|
|
28
|
+
int in_y = out_y_idx - kh;
|
|
29
|
+
if (in_y < 0 || in_y % stride) continue;
|
|
30
|
+
in_y /= stride;
|
|
31
|
+
if (in_y >= in_h) continue;
|
|
32
|
+
|
|
33
|
+
for (int kw = 0; kw < kernel_w; ++kw) {
|
|
34
|
+
int in_x = out_x_idx - kw;
|
|
35
|
+
if (in_x < 0 || in_x % stride) continue;
|
|
36
|
+
in_x /= stride;
|
|
37
|
+
if (in_x >= in_w) continue;
|
|
38
|
+
|
|
39
|
+
const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x;
|
|
40
|
+
const int kernel_idx =
|
|
41
|
+
(kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw;
|
|
42
|
+
|
|
43
|
+
float input_val = input[input_idx];
|
|
44
|
+
half kern_val = kernel[kernel_idx];
|
|
45
|
+
|
|
46
|
+
accumulator += input_val * (float) kern_val;
|
|
47
|
+
}
|
|
48
|
+
}
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
output[(out_w * out_h * c_out) * n_idx + (out_w * out_h) * c_idx + (out_w) *out_y_idx + out_x_idx] = accumulator;
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
//input is (W, H, C_in, N), Kernel is (W, H, C_out, C_in)
|
|
55
|
+
void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
56
|
+
const ggml_tensor * kernel = dst->src[0];
|
|
57
|
+
const ggml_tensor * input = dst->src[1];
|
|
58
|
+
|
|
59
|
+
GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
|
|
60
|
+
|
|
61
|
+
const float * input_data = (const float *) input->data;
|
|
62
|
+
float * output_data = (float *) dst->data;
|
|
63
|
+
const half * kernel_data = (const half *) kernel->data;
|
|
64
|
+
|
|
65
|
+
const int input_w = input->ne[0];
|
|
66
|
+
const int input_h = input->ne[1];
|
|
67
|
+
const int output_w = dst->ne[0];
|
|
68
|
+
const int output_h = dst->ne[1];
|
|
69
|
+
const int channels_in = input->ne[2];
|
|
70
|
+
const int channels_out = kernel->ne[2];
|
|
71
|
+
const int kernel_w = kernel->ne[0];
|
|
72
|
+
const int kernel_h = kernel->ne[1];
|
|
73
|
+
const int stride = dst->op_params[0];
|
|
74
|
+
const int batches = input->ne[3];
|
|
75
|
+
|
|
76
|
+
GGML_ASSERT(channels_in == kernel->ne[3]);
|
|
77
|
+
GGML_ASSERT(stride > 0);
|
|
78
|
+
|
|
79
|
+
cudaStream_t st = ctx.stream();
|
|
80
|
+
|
|
81
|
+
GGML_ASSERT(ggml_is_contiguous(input));
|
|
82
|
+
GGML_ASSERT(ggml_is_contiguous(kernel));
|
|
83
|
+
GGML_ASSERT(ggml_is_contiguous(dst));
|
|
84
|
+
|
|
85
|
+
const int total = (output_w * output_h * channels_out * batches);
|
|
86
|
+
const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;
|
|
87
|
+
|
|
88
|
+
conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
|
|
89
|
+
input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride,
|
|
90
|
+
channels_in, channels_out, batches);
|
|
91
|
+
}
|
|
@@ -728,3 +728,25 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
|
|
|
728
728
|
return nullptr;
|
|
729
729
|
}
|
|
730
730
|
}
|
|
731
|
+
|
|
732
|
+
to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
|
|
733
|
+
switch (type) {
|
|
734
|
+
case GGML_TYPE_F32:
|
|
735
|
+
return convert_unary_cuda<float, nv_bfloat16>;
|
|
736
|
+
case GGML_TYPE_F16:
|
|
737
|
+
return convert_unary_cuda<half, nv_bfloat16>;
|
|
738
|
+
default:
|
|
739
|
+
return nullptr;
|
|
740
|
+
}
|
|
741
|
+
}
|
|
742
|
+
|
|
743
|
+
to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
|
|
744
|
+
switch (type) {
|
|
745
|
+
case GGML_TYPE_F16:
|
|
746
|
+
return convert_unary_cuda<half, float>;
|
|
747
|
+
case GGML_TYPE_BF16:
|
|
748
|
+
return convert_unary_cuda<nv_bfloat16, float>;
|
|
749
|
+
default:
|
|
750
|
+
return nullptr;
|
|
751
|
+
}
|
|
752
|
+
}
|
|
@@ -22,5 +22,10 @@ using to_t_nc_cuda_t = void (*)(const void * x, T * y,
|
|
|
22
22
|
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
|
|
23
23
|
int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);
|
|
24
24
|
|
|
25
|
+
typedef to_t_nc_cuda_t<float> to_fp32_nc_cuda_t;
|
|
25
26
|
typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
|
|
27
|
+
typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
|
|
28
|
+
|
|
29
|
+
to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
|
|
26
30
|
to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
|
|
31
|
+
to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);
|
|
@@ -123,13 +123,7 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|
|
123
123
|
ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
|
|
124
124
|
|
|
125
125
|
if (nbytes_shared <= smpbo) {
|
|
126
|
-
|
|
127
|
-
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
|
128
|
-
if (!shared_memory_limit_raised[id]) {
|
|
129
|
-
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
|
|
130
|
-
shared_memory_limit_raised[id] = true;
|
|
131
|
-
}
|
|
132
|
-
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
|
126
|
+
CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), smpbo);
|
|
133
127
|
cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
|
134
128
|
} else {
|
|
135
129
|
cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
|
@@ -175,13 +169,7 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
|
|
|
175
169
|
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
|
176
170
|
|
|
177
171
|
if (nbytes_shared <= smpbo) {
|
|
178
|
-
|
|
179
|
-
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
|
180
|
-
if (!shared_memory_limit_raised[id]) {
|
|
181
|
-
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
|
|
182
|
-
shared_memory_limit_raised[id] = true;
|
|
183
|
-
}
|
|
184
|
-
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
|
172
|
+
CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), smpbo);
|
|
185
173
|
cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
|
|
186
174
|
} else {
|
|
187
175
|
cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
|
|
@@ -32,7 +32,9 @@ typedef void (* fattn_kernel_t)(
|
|
|
32
32
|
const int ne12,
|
|
33
33
|
const int ne13,
|
|
34
34
|
const int ne31,
|
|
35
|
+
const int ne32,
|
|
35
36
|
const int nb31,
|
|
37
|
+
const int nb32,
|
|
36
38
|
const int nb01,
|
|
37
39
|
const int nb02,
|
|
38
40
|
const int nb03,
|
|
@@ -851,7 +853,8 @@ void launch_fattn(
|
|
|
851
853
|
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
|
852
854
|
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
|
853
855
|
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
|
854
|
-
mask ? mask->ne[1] : 0, mask ?
|
|
856
|
+
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
|
|
857
|
+
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
|
|
855
858
|
Q->nb[1], Q->nb[2], Q->nb[3],
|
|
856
859
|
nb11, nb12, nb13,
|
|
857
860
|
nb21, nb22, nb23,
|
|
@@ -1223,7 +1223,9 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1223
1223
|
const int ne12,
|
|
1224
1224
|
const int ne13,
|
|
1225
1225
|
const int ne31,
|
|
1226
|
+
const int ne32,
|
|
1226
1227
|
const int nb31,
|
|
1228
|
+
const int nb32,
|
|
1227
1229
|
const int nb01,
|
|
1228
1230
|
const int nb02,
|
|
1229
1231
|
const int nb03,
|
|
@@ -1288,7 +1290,8 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1288
1290
|
|
|
1289
1291
|
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
|
1290
1292
|
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
|
1291
|
-
const half2 * mask_h2 = ncols2
|
|
1293
|
+
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
|
|
1294
|
+
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
|
|
1292
1295
|
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
|
1293
1296
|
|
|
1294
1297
|
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
|
@@ -1327,7 +1330,8 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1327
1330
|
|
|
1328
1331
|
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
|
1329
1332
|
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
|
1330
|
-
const half2 * mask_h2 = ncols2
|
|
1333
|
+
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
|
|
1334
|
+
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
|
|
1331
1335
|
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
|
1332
1336
|
|
|
1333
1337
|
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
|
@@ -1348,8 +1352,8 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1348
1352
|
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
|
1349
1353
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
|
|
1350
1354
|
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
|
|
1351
|
-
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
|
1352
|
-
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
|
1355
|
+
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
|
1356
|
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
|
1353
1357
|
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
|
|
1354
1358
|
GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
|
1355
1359
|
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
|
@@ -6,7 +6,7 @@
|
|
|
6
6
|
|
|
7
7
|
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
|
|
8
8
|
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
9
|
-
__launch_bounds__(nwarps*WARP_SIZE,
|
|
9
|
+
__launch_bounds__(nwarps*WARP_SIZE, 2)
|
|
10
10
|
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
11
11
|
static __global__ void flash_attn_tile_ext_f16(
|
|
12
12
|
const char * __restrict__ Q,
|
|
@@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
30
30
|
const int ne12,
|
|
31
31
|
const int ne13,
|
|
32
32
|
const int ne31,
|
|
33
|
+
const int ne32,
|
|
33
34
|
const int nb31,
|
|
35
|
+
const int nb32,
|
|
34
36
|
const int nb01,
|
|
35
37
|
const int nb02,
|
|
36
38
|
const int nb03,
|
|
@@ -64,7 +66,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
64
66
|
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
|
|
65
67
|
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
|
|
66
68
|
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
|
67
|
-
const half * maskh = (const half *)
|
|
69
|
+
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
|
68
70
|
|
|
69
71
|
const int stride_KV2 = nb11 / sizeof(half2);
|
|
70
72
|
|
|
@@ -288,8 +290,8 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
288
290
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
289
291
|
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
|
290
292
|
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
|
291
|
-
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
|
292
|
-
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
293
|
+
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
|
294
|
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
293
295
|
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
|
294
296
|
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
|
295
297
|
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
|
@@ -6,7 +6,7 @@
|
|
|
6
6
|
|
|
7
7
|
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
|
|
8
8
|
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
9
|
-
__launch_bounds__(nwarps*WARP_SIZE,
|
|
9
|
+
__launch_bounds__(nwarps*WARP_SIZE, 2)
|
|
10
10
|
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
11
11
|
static __global__ void flash_attn_tile_ext_f32(
|
|
12
12
|
const char * __restrict__ Q,
|
|
@@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f32(
|
|
|
30
30
|
const int ne12,
|
|
31
31
|
const int ne13,
|
|
32
32
|
const int ne31,
|
|
33
|
+
const int ne32,
|
|
33
34
|
const int nb31,
|
|
35
|
+
const int nb32,
|
|
34
36
|
const int nb01,
|
|
35
37
|
const int nb02,
|
|
36
38
|
const int nb03,
|
|
@@ -58,8 +60,8 @@ static __global__ void flash_attn_tile_ext_f32(
|
|
|
58
60
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
59
61
|
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
|
60
62
|
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
|
61
|
-
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
|
62
|
-
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
63
|
+
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
|
64
|
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
63
65
|
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
|
64
66
|
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
|
65
67
|
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
|
@@ -76,7 +78,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
|
|
76
78
|
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
|
|
77
79
|
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
|
|
78
80
|
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
|
79
|
-
const half * maskh = (const half *)
|
|
81
|
+
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
|
80
82
|
|
|
81
83
|
const int stride_KV2 = nb11 / sizeof(half2);
|
|
82
84
|
|
|
@@ -297,14 +299,14 @@ static __global__ void flash_attn_tile_ext_f32(
|
|
|
297
299
|
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
|
298
300
|
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
|
299
301
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
300
|
-
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
|
301
|
-
GGML_UNUSED(
|
|
302
|
-
GGML_UNUSED(
|
|
303
|
-
GGML_UNUSED(nb31); GGML_UNUSED(
|
|
304
|
-
GGML_UNUSED(
|
|
305
|
-
GGML_UNUSED(
|
|
306
|
-
GGML_UNUSED(
|
|
307
|
-
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
|
302
|
+
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
|
303
|
+
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
|
304
|
+
GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
|
305
|
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32);
|
|
306
|
+
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
|
307
|
+
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
|
308
|
+
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
|
309
|
+
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
|
308
310
|
NO_DEVICE_CODE;
|
|
309
311
|
#endif // FLASH_ATTN_AVAILABLE
|
|
310
312
|
}
|
|
@@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
|
27
27
|
const int ne12,
|
|
28
28
|
const int ne13,
|
|
29
29
|
const int ne31,
|
|
30
|
+
const int ne32,
|
|
30
31
|
const int nb31,
|
|
32
|
+
const int nb32,
|
|
31
33
|
const int nb01,
|
|
32
34
|
const int nb02,
|
|
33
35
|
const int nb03,
|
|
@@ -68,7 +70,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
|
68
70
|
K += nb12*(blockIdx.z / gqa_ratio);
|
|
69
71
|
V += nb22*(blockIdx.z / gqa_ratio);
|
|
70
72
|
|
|
71
|
-
const half * maskh = (const half
|
|
73
|
+
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
|
72
74
|
|
|
73
75
|
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
|
74
76
|
const half slopeh = __float2half(slopef);
|
|
@@ -342,8 +344,8 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
|
342
344
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
343
345
|
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
|
344
346
|
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
|
345
|
-
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
|
346
|
-
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
347
|
+
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
|
348
|
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
347
349
|
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
|
348
350
|
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
|
349
351
|
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
|
@@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
|
27
27
|
const int ne12,
|
|
28
28
|
const int ne13,
|
|
29
29
|
const int ne31,
|
|
30
|
+
const int ne32,
|
|
30
31
|
const int nb31,
|
|
32
|
+
const int nb32,
|
|
31
33
|
const int nb01,
|
|
32
34
|
const int nb02,
|
|
33
35
|
const int nb03,
|
|
@@ -51,8 +53,8 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
|
51
53
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
52
54
|
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
|
53
55
|
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
|
54
|
-
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
|
55
|
-
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
56
|
+
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
|
57
|
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
56
58
|
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
|
57
59
|
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
|
58
60
|
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
|
@@ -79,7 +81,8 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
|
79
81
|
Q += nb02* blockIdx.z + nb01*ic0;
|
|
80
82
|
K += nb12*(blockIdx.z / gqa_ratio);
|
|
81
83
|
V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
|
|
82
|
-
|
|
84
|
+
|
|
85
|
+
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
|
83
86
|
|
|
84
87
|
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
|
85
88
|
|
|
@@ -334,13 +337,15 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
|
334
337
|
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
|
335
338
|
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
|
336
339
|
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
|
337
|
-
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
338
|
-
GGML_UNUSED(
|
|
339
|
-
GGML_UNUSED(
|
|
340
|
-
GGML_UNUSED(
|
|
341
|
-
GGML_UNUSED(
|
|
342
|
-
GGML_UNUSED(
|
|
343
|
-
GGML_UNUSED(
|
|
340
|
+
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
341
|
+
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
|
342
|
+
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
|
343
|
+
GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
|
344
|
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32);
|
|
345
|
+
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
|
346
|
+
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
|
347
|
+
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
|
348
|
+
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
|
344
349
|
NO_DEVICE_CODE;
|
|
345
350
|
#endif // FLASH_ATTN_AVAILABLE
|
|
346
351
|
}
|
|
@@ -9,7 +9,11 @@
|
|
|
9
9
|
#ifdef FP16_MMA_AVAILABLE
|
|
10
10
|
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
11
11
|
#include <mma.h>
|
|
12
|
+
#ifdef GGML_USE_MUSA
|
|
13
|
+
namespace wmma = mtmusa::wmma;
|
|
14
|
+
#else // GGML_USE_MUSA
|
|
12
15
|
namespace wmma = nvcuda::wmma;
|
|
16
|
+
#endif // GGML_USE_MUSA
|
|
13
17
|
#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
|
|
14
18
|
#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
|
|
15
19
|
#include <rocwmma/rocwmma.hpp>
|
|
@@ -42,7 +46,9 @@ static __global__ void flash_attn_ext_f16(
|
|
|
42
46
|
const int ne12,
|
|
43
47
|
const int ne13,
|
|
44
48
|
const int ne31,
|
|
49
|
+
const int ne32,
|
|
45
50
|
const int nb31,
|
|
51
|
+
const int nb32,
|
|
46
52
|
const int nb01,
|
|
47
53
|
const int nb02,
|
|
48
54
|
const int nb03,
|
|
@@ -90,11 +96,11 @@ static __global__ void flash_attn_ext_f16(
|
|
|
90
96
|
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
|
|
91
97
|
|
|
92
98
|
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
|
93
|
-
const float * Q_f = (const float *) (Q
|
|
94
|
-
const half * K_h = (const half *) (K
|
|
95
|
-
const half * V_h = (const half *) (V
|
|
96
|
-
const half * maskh = (const half *)
|
|
97
|
-
const half2 * mask2 = (const half2 *)
|
|
99
|
+
const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
|
|
100
|
+
const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
|
|
101
|
+
const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
|
102
|
+
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
|
103
|
+
const half2 * mask2 = (const half2 *) maskh;
|
|
98
104
|
|
|
99
105
|
const int stride_Q = nb01 / sizeof(float);
|
|
100
106
|
const int stride_KV = nb11 / sizeof(half);
|
|
@@ -436,7 +442,7 @@ static __global__ void flash_attn_ext_f16(
|
|
|
436
442
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
437
443
|
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
|
438
444
|
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
|
439
|
-
GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
445
|
+
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
440
446
|
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
|
441
447
|
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
|
442
448
|
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
|
@@ -168,6 +168,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
|
|
|
168
168
|
get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
|
|
169
169
|
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
|
170
170
|
break;
|
|
171
|
+
case GGML_TYPE_I32:
|
|
172
|
+
get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d,
|
|
173
|
+
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
|
174
|
+
break;
|
|
171
175
|
case GGML_TYPE_BF16:
|
|
172
176
|
get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
|
|
173
177
|
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
|
@@ -210,6 +214,10 @@ void get_rows_cuda(
|
|
|
210
214
|
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
|
|
211
215
|
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
|
212
216
|
break;
|
|
217
|
+
case GGML_TYPE_I32:
|
|
218
|
+
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d,
|
|
219
|
+
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
|
220
|
+
break;
|
|
213
221
|
case GGML_TYPE_F16:
|
|
214
222
|
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
|
|
215
223
|
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|