@novastera-oss/llamarn 0.2.7 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +8 -3
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +56 -22
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +1 -2
- package/cpp/llama.cpp/README.md +4 -5
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +24 -0
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +3 -0
- package/cpp/llama.cpp/common/common.h +5 -0
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
- package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
- package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
- package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
- package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
- package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
- package/cpp/llama.cpp/include/llama.h +8 -43
- package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
- package/cpp/llama.cpp/src/llama-arch.h +36 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
- package/cpp/llama.cpp/src/llama-batch.h +105 -70
- package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +101 -107
- package/cpp/llama.cpp/src/llama-context.h +13 -13
- package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
- package/cpp/llama.cpp/src/llama-graph.h +78 -79
- package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
- package/cpp/llama.cpp/src/llama-hparams.h +11 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
- package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +21 -22
- package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
- package/cpp/llama.cpp/src/llama-model.h +40 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
- package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
- package/cpp/llama.cpp/src/llama-vocab.h +42 -0
- package/cpp/rn-utils.h +3 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +5 -0
- package/ios/include/llama.h +8 -43
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -1,18 +1,18 @@
|
|
|
1
1
|
#include "scale.cuh"
|
|
2
2
|
|
|
3
|
-
static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
|
|
3
|
+
static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) {
|
|
4
4
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
5
5
|
|
|
6
6
|
if (i >= k) {
|
|
7
7
|
return;
|
|
8
8
|
}
|
|
9
9
|
|
|
10
|
-
dst[i] = scale * x[i];
|
|
10
|
+
dst[i] = scale * x[i] + bias;
|
|
11
11
|
}
|
|
12
12
|
|
|
13
|
-
static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
|
|
13
|
+
static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) {
|
|
14
14
|
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
|
|
15
|
-
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
|
|
15
|
+
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, k);
|
|
16
16
|
}
|
|
17
17
|
|
|
18
18
|
void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
@@ -25,7 +25,9 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
25
25
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
26
26
|
|
|
27
27
|
float scale;
|
|
28
|
-
|
|
28
|
+
float bias;
|
|
29
|
+
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
|
30
|
+
memcpy(&bias, (float *) dst->op_params + 1, sizeof(float));
|
|
29
31
|
|
|
30
|
-
scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
|
|
32
|
+
scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream);
|
|
31
33
|
}
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
#include "ggml.h"
|
|
3
3
|
#include "softmax.cuh"
|
|
4
4
|
#include <cstdint>
|
|
5
|
+
#include <utility>
|
|
5
6
|
|
|
6
7
|
template <typename T>
|
|
7
8
|
static __device__ __forceinline__ float t2f32(T val) {
|
|
@@ -13,6 +14,29 @@ __device__ float __forceinline__ t2f32<half>(half val) {
|
|
|
13
14
|
return __half2float(val);
|
|
14
15
|
}
|
|
15
16
|
|
|
17
|
+
struct soft_max_params {
|
|
18
|
+
|
|
19
|
+
int64_t nheads;
|
|
20
|
+
uint32_t n_head_log2;
|
|
21
|
+
int64_t ncols;
|
|
22
|
+
int64_t nrows_x;
|
|
23
|
+
int64_t nrows_y;
|
|
24
|
+
int64_t ne00;
|
|
25
|
+
int64_t ne01;
|
|
26
|
+
int64_t ne02;
|
|
27
|
+
int64_t ne03;
|
|
28
|
+
int64_t nb11;
|
|
29
|
+
int64_t nb12;
|
|
30
|
+
int64_t nb13;
|
|
31
|
+
|
|
32
|
+
int64_t ne12;
|
|
33
|
+
int64_t ne13;
|
|
34
|
+
float scale;
|
|
35
|
+
float max_bias;
|
|
36
|
+
float m0;
|
|
37
|
+
float m1;
|
|
38
|
+
};
|
|
39
|
+
|
|
16
40
|
// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
|
|
17
41
|
// As we want to keep pragma unroll for all other cases we suppress the clang transformation warning here.
|
|
18
42
|
#ifdef __clang__
|
|
@@ -21,16 +45,24 @@ __device__ float __forceinline__ t2f32<half>(half val) {
|
|
|
21
45
|
#endif // __clang__
|
|
22
46
|
template <bool use_shared, int ncols_template, int block_size_template, typename T>
|
|
23
47
|
static __global__ void soft_max_f32(
|
|
24
|
-
const float * x, const T * mask, float * dst, const
|
|
25
|
-
|
|
26
|
-
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
|
48
|
+
const float * x, const T * mask, float * dst, const soft_max_params p) {
|
|
49
|
+
const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
|
|
27
50
|
|
|
28
51
|
const int tid = threadIdx.x;
|
|
29
|
-
|
|
30
|
-
const
|
|
52
|
+
|
|
53
|
+
const int64_t i03 = blockIdx.z;
|
|
54
|
+
const int64_t i02 = blockIdx.y;
|
|
55
|
+
const int64_t i01 = blockIdx.x;
|
|
56
|
+
|
|
57
|
+
//TODO: noncontiguous inputs/outputs
|
|
58
|
+
const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
|
|
59
|
+
|
|
60
|
+
const int64_t i11 = i01;
|
|
61
|
+
const int64_t i12 = i02 % p.ne12;
|
|
62
|
+
const int64_t i13 = i03 % p.ne13;
|
|
31
63
|
|
|
32
64
|
x += int64_t(rowx)*ncols;
|
|
33
|
-
mask +=
|
|
65
|
+
mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
|
|
34
66
|
dst += int64_t(rowx)*ncols;
|
|
35
67
|
|
|
36
68
|
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
|
|
@@ -38,7 +70,7 @@ static __global__ void soft_max_f32(
|
|
|
38
70
|
const int warp_id = threadIdx.x / WARP_SIZE;
|
|
39
71
|
const int lane_id = threadIdx.x % WARP_SIZE;
|
|
40
72
|
|
|
41
|
-
const float slope = get_alibi_slope(max_bias,
|
|
73
|
+
const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
|
|
42
74
|
|
|
43
75
|
extern __shared__ float data_soft_max_f32[];
|
|
44
76
|
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
|
|
@@ -55,7 +87,7 @@ static __global__ void soft_max_f32(
|
|
|
55
87
|
break;
|
|
56
88
|
}
|
|
57
89
|
|
|
58
|
-
const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
|
|
90
|
+
const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
|
|
59
91
|
|
|
60
92
|
vals[col] = val;
|
|
61
93
|
max_val = max(max_val, val);
|
|
@@ -150,64 +182,58 @@ static __global__ void soft_max_back_f32(
|
|
|
150
182
|
}
|
|
151
183
|
}
|
|
152
184
|
|
|
185
|
+
template<int... Ns, typename T>
|
|
186
|
+
static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
|
|
187
|
+
const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
|
|
188
|
+
{
|
|
189
|
+
const int id = ggml_cuda_get_device();
|
|
190
|
+
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
|
191
|
+
|
|
192
|
+
auto launch_kernel = [=](auto I) -> bool {
|
|
193
|
+
constexpr int ncols = decltype(I)::value;
|
|
194
|
+
constexpr int block = (ncols > 1024 ? 1024 : ncols);
|
|
195
|
+
|
|
196
|
+
if (p.ncols == ncols) {
|
|
197
|
+
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
|
|
198
|
+
soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
|
|
199
|
+
(x, mask, dst, p);
|
|
200
|
+
return true;
|
|
201
|
+
}
|
|
202
|
+
return false;
|
|
203
|
+
};
|
|
204
|
+
|
|
205
|
+
// unary fold over launch_kernel
|
|
206
|
+
if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
|
|
207
|
+
return;
|
|
208
|
+
}
|
|
209
|
+
|
|
210
|
+
//default case
|
|
211
|
+
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
|
|
212
|
+
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
|
|
153
216
|
// Host-side launcher for the soft_max_f32 kernels.
//
// x      - input values
// mask   - optional mask of element type T (forwarded to the kernel unchanged)
// dst    - output buffer
// params - problem description; this function itself reads ncols, ne01, ne02 and ne03
// stream - CUDA stream the kernel is enqueued on
//
// The grid is (ne01, ne02, ne03). The block size is WARP_SIZE doubled until it
// reaches either the number of columns or CUDA_SOFT_MAX_BLOCK_SIZE.
template<typename T>
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");

    const int64_t ncols_x = params.ncols;

    int nth = WARP_SIZE;
    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) {
        nth *= 2;
    }

    const dim3 grid(params.ne01, params.ne02, params.ne03);
    const dim3 block(nth, 1, 1);

    // Dynamic shared memory requested for the <true, ...> kernel variants.
    // NOTE(review): presumably one row padded to WARP_SIZE plus a WARP_SIZE-sized
    // reduction buffer - confirm against the kernel.
    const size_t smem_full = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);

    // Per-device limit that smem_full is compared against.
    // NOTE(review): presumably the opt-in max shared memory per block - confirm.
    const int    device = ggml_cuda_get_device();
    const size_t smpbo  = ggml_cuda_info().devices[device].smpbo;

    if (nbytes_exceeds_limit_guard: false) {}
}
|
|
213
239
|
|
|
@@ -235,10 +261,11 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
235
261
|
|
|
236
262
|
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
|
237
263
|
|
|
238
|
-
const int64_t ne00 = src0->ne[0];
|
|
239
264
|
const int64_t nrows_x = ggml_nrows(src0);
|
|
240
265
|
const int64_t nrows_y = src0->ne[1];
|
|
241
266
|
|
|
267
|
+
const int64_t ne00 = src0->ne[0];
|
|
268
|
+
|
|
242
269
|
float scale = 1.0f;
|
|
243
270
|
float max_bias = 0.0f;
|
|
244
271
|
|
|
@@ -247,10 +274,44 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
247
274
|
|
|
248
275
|
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
|
249
276
|
|
|
277
|
+
const int64_t nb11 = src1 ? src1->nb[1] : 1;
|
|
278
|
+
const int64_t nb12 = src1 ? src1->nb[2] : 1;
|
|
279
|
+
const int64_t nb13 = src1 ? src1->nb[3] : 1;
|
|
280
|
+
|
|
281
|
+
const int64_t ne12 = src1 ? src1->ne[2] : 1;
|
|
282
|
+
const int64_t ne13 = src1 ? src1->ne[3] : 1;
|
|
283
|
+
|
|
284
|
+
const uint32_t n_head = src0->ne[2];
|
|
285
|
+
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
|
286
|
+
|
|
287
|
+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
288
|
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
soft_max_params params = {};
|
|
292
|
+
params.nheads = src0->ne[2];
|
|
293
|
+
params.n_head_log2 = n_head_log2;
|
|
294
|
+
params.ncols = ne00;
|
|
295
|
+
params.nrows_x = nrows_x;
|
|
296
|
+
params.nrows_y = nrows_y;
|
|
297
|
+
params.ne00 = src0->ne[0];
|
|
298
|
+
params.ne01 = src0->ne[1];
|
|
299
|
+
params.ne02 = src0->ne[2];
|
|
300
|
+
params.ne03 = src0->ne[3];
|
|
301
|
+
params.nb11 = nb11;
|
|
302
|
+
params.nb12 = nb12;
|
|
303
|
+
params.nb13 = nb13;
|
|
304
|
+
params.ne12 = ne12;
|
|
305
|
+
params.ne13 = ne13;
|
|
306
|
+
params.scale = scale;
|
|
307
|
+
params.max_bias = max_bias;
|
|
308
|
+
params.m0 = m0;
|
|
309
|
+
params.m1 = m1;
|
|
310
|
+
|
|
250
311
|
if (use_f16) {
|
|
251
|
-
soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d,
|
|
312
|
+
soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, stream);
|
|
252
313
|
} else {
|
|
253
|
-
soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d,
|
|
314
|
+
soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
|
|
254
315
|
}
|
|
255
316
|
}
|
|
256
317
|
|
|
@@ -107,8 +107,11 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
|
|
|
107
107
|
if (nc == 4) {
|
|
108
108
|
ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
|
|
109
109
|
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
|
|
110
|
+
} else if (nc == 3) {
|
|
111
|
+
ssm_conv_f32<threads, 3><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
|
|
112
|
+
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
|
|
110
113
|
} else {
|
|
111
|
-
GGML_ABORT("Only support kernel size = 4
|
|
114
|
+
GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
|
|
112
115
|
}
|
|
113
116
|
} else {
|
|
114
117
|
if (nc == 4) {
|
|
@@ -116,8 +119,13 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
|
|
|
116
119
|
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
|
|
117
120
|
ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
|
|
118
121
|
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
|
|
122
|
+
} else if (nc == 3) {
|
|
123
|
+
const int64_t split_n_t = 32;
|
|
124
|
+
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
|
|
125
|
+
ssm_conv_long_token_f32<threads, 3, split_n_t><<<blocks, threads, 0, stream>>>(
|
|
126
|
+
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
|
|
119
127
|
} else {
|
|
120
|
-
GGML_ABORT("Only support kernel size = 4 right now.");
|
|
128
|
+
GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
|
|
121
129
|
}
|
|
122
130
|
}
|
|
123
131
|
}
|
|
@@ -4,16 +4,15 @@ template <size_t splitD, size_t N>
|
|
|
4
4
|
__global__ void __launch_bounds__(splitD, 2)
|
|
5
5
|
ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
|
|
6
6
|
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
|
|
7
|
-
const
|
|
8
|
-
const int
|
|
9
|
-
const int
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
GGML_UNUSED(src2_nb0);
|
|
7
|
+
const int32_t * __restrict__ src6, float * __restrict__ dst,
|
|
8
|
+
const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
|
|
9
|
+
const int src2_nb1, const int src2_nb2, const int src3_nb1,
|
|
10
|
+
const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
|
|
11
|
+
const int64_t s_off, const int64_t d_inner, const int64_t L) {
|
|
13
12
|
|
|
14
13
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
15
|
-
const int bidx = blockIdx.x; // split along B
|
|
16
|
-
const int bidy = blockIdx.y; // split along D
|
|
14
|
+
const int bidx = blockIdx.x; // split along B (sequences)
|
|
15
|
+
const int bidy = blockIdx.y; // split along D (d_inner)
|
|
17
16
|
const int tid = threadIdx.x;
|
|
18
17
|
const int wid = tid / 32;
|
|
19
18
|
const int wtid = tid % 32;
|
|
@@ -24,23 +23,23 @@ __global__ void __launch_bounds__(splitD, 2)
|
|
|
24
23
|
float * smem_A = smem;
|
|
25
24
|
float * smem_s0 = smem_A + splitD * stride_sA;
|
|
26
25
|
|
|
27
|
-
const float * s0_block = (const float *) ((const char *) src0 + bidx *
|
|
28
|
-
const float * x_block = (const float *) ((const char *) src1 + (bidx *
|
|
26
|
+
const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
|
|
27
|
+
const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
|
|
29
28
|
const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
|
|
30
29
|
const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
|
|
31
|
-
const float * B_block = (const float *) ((const char *) src4 + (bidx *
|
|
32
|
-
const float * C_block = (const float *) ((const char *) src5 + (bidx *
|
|
33
|
-
float * y_block = (float *) ((char *) dst + (bidx *
|
|
34
|
-
float * s_block = (float *) ((char *) dst +
|
|
30
|
+
const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb3));
|
|
31
|
+
const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb3));
|
|
32
|
+
float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
|
|
33
|
+
float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);
|
|
35
34
|
|
|
36
|
-
const int stride_s0 =
|
|
37
|
-
const int stride_x =
|
|
35
|
+
const int stride_s0 = src0_nb2 / sizeof(float);
|
|
36
|
+
const int stride_x = src1_nb2 / sizeof(float);
|
|
38
37
|
const int stride_dt = src2_nb1 / sizeof(float);
|
|
39
38
|
const int stride_A = src3_nb1 / sizeof(float);
|
|
40
|
-
const int stride_B =
|
|
41
|
-
const int stride_C =
|
|
39
|
+
const int stride_B = src4_nb2 / sizeof(float);
|
|
40
|
+
const int stride_C = src5_nb2 / sizeof(float);
|
|
42
41
|
const int stride_s = stride_s0;
|
|
43
|
-
const int stride_y =
|
|
42
|
+
const int stride_y = d_inner;
|
|
44
43
|
|
|
45
44
|
// can N not be 16? for example 32?
|
|
46
45
|
if (N == 16) {
|
|
@@ -84,24 +83,167 @@ __global__ void __launch_bounds__(splitD, 2)
|
|
|
84
83
|
}
|
|
85
84
|
}
|
|
86
85
|
|
|
86
|
+
// Selective-scan kernel used by the launcher for the "Mamba-2" / "Falcon-H1" paths.
//
// Launch layout expected by the code below:
//   - blockDim.x == d_state: one thread per state element (assumes as many threads as d_state)
//   - blockIdx.x: chunk of splitH consecutive elements of the flattened (n_head * d_head) dimension
//   - blockIdx.y: sequence index
//   - static shared memory: splitH * d_state floats (stateC)
//
// src0 is the input state, read through src6: src6[seq] selects the index along src0's
// dimension 3 for sequence seq. src1 is x, src2 is dt, src3 is A, src4 is B, src5 is C.
// dst receives the y outputs first; the updated states are written starting at byte
// offset s_off. All *_nb* arguments are byte strides.
template <int splitH, int d_state>
__global__ void __launch_bounds__(d_state, 1)
    ssm_scan_f32_group(
        const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
        const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
        const int32_t * __restrict__ src6, float * __restrict__ dst,
        const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
        const int src2_nb1, const int src2_nb2, const int src3_nb1,
        const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
        const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {

    const int head_idx = (blockIdx.x * splitH) / d_head;
    // byte offset of this block's first element inside its head
    const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
    const int seq_idx  = blockIdx.y;

    // NOTE(review): `head_idx & (n_group - 1)` equals `head_idx % n_group` only when n_group is a
    // power of two - confirm that the caller guarantees this and that this head-to-group mapping
    // is the intended one.
    const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);

    const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
    const float * x_block  = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
    const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
    const float * A_block  = (const float *) ((const char *) src3 + head_idx * src3_nb1);
    const float * B_block  = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
    const float * C_block  = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
    float *       y_block  = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
    float *       s_block  = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);

    // strides across n_seq_tokens
    const int stride_x  = src1_nb2 / sizeof(float);
    const int stride_dt = src2_nb1 / sizeof(float);
    const int stride_B  = src4_nb2 / sizeof(float);
    const int stride_C  = src5_nb2 / sizeof(float);
    const int stride_y  = n_head * d_head;

    // per-thread running state: one value per head element handled by this block
    float state[splitH];
    // for the parallel accumulation
    __shared__ float stateC[splitH * d_state];

#pragma unroll
    for (int j = 0; j < splitH; j++) {
        state[j] = s0_block[j * d_state + threadIdx.x];
    }

    for (int64_t i = 0; i < n_tok; i++) {
        // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
        // TODO: only calculate B and C once per head group
        // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
        float dt_soft_plus = dt_block[i * stride_dt];
        // softplus(dt) = log(1 + exp(dt)); for dt > 20 the raw value is used as-is
        if (dt_soft_plus <= 20.0f) {
            dt_soft_plus = log1pf(expf(dt_soft_plus));
        }
        const float dA = expf(dt_soft_plus * A_block[0]);
        const float B  = B_block[i * stride_B + threadIdx.x];
        const float C  = C_block[i * stride_C + threadIdx.x];

        // across d_head
#pragma unroll
        for (int j = 0; j < splitH; j++) {
            const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;

            state[j] = (state[j] * dA) + (B * x_dt);

            stateC[j * d_state + threadIdx.x] = state[j] * C;
        }

        __syncthreads();

        // parallel accumulation for stateC
        // TODO: simplify
        {
            static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
            static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");

            // reduce until w matches the warp size
            // TODO: does this work even when the physical warp size is 64?
#pragma unroll
            for (int w = d_state; w > WARP_SIZE; w >>= 1) {
                // (assuming there are d_state threads)
#pragma unroll
                for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
                    // TODO: check for bank conflicts
                    const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
                    stateC[k] += stateC[k + (w >> 1)];

                }
                __syncthreads();
            }

            static_assert(splitH >= d_state / WARP_SIZE);

#pragma unroll
            for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
                float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
                y = warp_reduce_sum(y);

                // store the above accumulations
                if (threadIdx.x % WARP_SIZE == 0) {
                    const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
                    y_block[i * stride_y + k] = y;
                }
            }
        }

        // The reads above take columns [0, WARP_SIZE) of every stateC row, and those columns are
        // rewritten by the first warp at the start of the next token iteration. Without a barrier
        // here a fast warp could overwrite values that a slower warp has not read yet.
        // The loop bound n_tok is uniform across the block, so every thread reaches this barrier.
        __syncthreads();
    }

    // write back the state
#pragma unroll
    for (int j = 0; j < splitH; j++) {
        s_block[j * d_state + threadIdx.x] = state[j];
    }
}
|
|
196
|
+
|
|
87
197
|
// Host-side dispatcher for the SSM scan kernels.
//
// Chooses a kernel from the A tensor's row stride (src3_nb1) and d_state:
//   - src3_nb1 != sizeof(float): ssm_scan_f32 ("Mamba-1"), only d_state == 16,
//     requires head_dim == 1, n_group == 1 and n_head divisible by 128
//   - src3_nb1 == sizeof(float): ssm_scan_f32_group ("Mamba-2" for d_state == 128,
//     "Falcon-H1" for d_state == 256), requires head_dim divisible by 16
// Any other d_state aborts. All *_nb* arguments are byte strides; s_off is the byte
// offset inside dst at which the kernels write the updated states.
static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
                              const float * src4, const float * src5, const int32_t * src6, float * dst,
                              const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
                              const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
                              const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
                              const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                              cudaStream_t stream) {
    // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
    if (src3_nb1 != sizeof(float)) {
        // Mamba-1
        const int block_size = 128;
        GGML_ASSERT(n_head % block_size == 0);
        GGML_ASSERT(head_dim == 1);
        GGML_ASSERT(n_group == 1);

        // grid: x = sequences, y = chunks of block_size along n_head
        const dim3 grid(n_seq, (n_head + block_size - 1) / block_size, 1);
        const int  smem_size = (block_size * (d_state + 1) * 2) * sizeof(float);

        if (d_state != 16) {
            GGML_ABORT("doesn't support d_state!=16.");
        }
        ssm_scan_f32<128, 16><<<grid, block_size, smem_size, stream>>>(
            src0, src1, src2, src3, src4, src5, src6, dst,
            src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
            src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
        return;
    }

    // Mamba-2 / Falcon-H1: one thread per state element, splitH head elements per block.
    // NOTE: splitH can be any power of two between 4 and 64 for d_state == 128,
    //       and between 8 and 64 for d_state == 256
    constexpr int splitH = 16;

    switch (d_state) {
        case 128: // Mamba-2
            {
                const int block_size = 128;
                GGML_ASSERT(d_state % block_size == 0);
                GGML_ASSERT(head_dim % splitH == 0);
                const dim3 grid((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
                ssm_scan_f32_group<splitH, 128><<<grid, block_size, 0, stream>>>(
                    src0, src1, src2, src3, src4, src5, src6, dst,
                    src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                    src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
            }
            break;
        case 256: // Falcon-H1
            {
                const int block_size = 256;
                GGML_ASSERT(head_dim % splitH == 0);
                const dim3 grid((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
                ssm_scan_f32_group<splitH, 256><<<grid, block_size, 0, stream>>>(
                    src0, src1, src2, src3, src4, src5, src6, dst,
                    src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                    src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
            }
            break;
        default:
            GGML_ABORT("doesn't support d_state!=(128 or 256).");
    }
}
|
|
107
249
|
|
|
@@ -112,30 +254,25 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
112
254
|
const struct ggml_tensor * src3 = dst->src[3]; // A
|
|
113
255
|
const struct ggml_tensor * src4 = dst->src[4]; // B
|
|
114
256
|
const struct ggml_tensor * src5 = dst->src[5]; // C
|
|
115
|
-
|
|
116
|
-
// const int64_t d_state = src0->ne[0];
|
|
117
|
-
// const int64_t d_inner = src0->ne[1];
|
|
118
|
-
// const int64_t l = src1->ne[1];
|
|
119
|
-
// const int64_t b = src0->ne[2];
|
|
257
|
+
const struct ggml_tensor * src6 = dst->src[6]; // ids
|
|
120
258
|
|
|
121
259
|
const int64_t nc = src0->ne[0]; // d_state
|
|
122
|
-
const int64_t nr = src0->ne[1]; //
|
|
123
|
-
const int64_t
|
|
124
|
-
const int64_t
|
|
260
|
+
const int64_t nr = src0->ne[1]; // head_dim or 1
|
|
261
|
+
const int64_t nh = src1->ne[1]; // n_head
|
|
262
|
+
const int64_t ng = src4->ne[1]; // n_group
|
|
263
|
+
const int64_t n_t = src1->ne[2]; // number of tokens per sequence
|
|
264
|
+
const int64_t n_s = src1->ne[3]; // number of sequences in the batch
|
|
265
|
+
|
|
266
|
+
const int64_t s_off = ggml_nelements(src1) * sizeof(float);
|
|
125
267
|
|
|
126
|
-
GGML_ASSERT(ggml_nelements(src1) +
|
|
268
|
+
GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst));
|
|
127
269
|
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
|
128
270
|
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
|
129
271
|
GGML_ASSERT(src2->nb[0] == sizeof(float));
|
|
130
272
|
GGML_ASSERT(src3->nb[0] == sizeof(float));
|
|
131
273
|
GGML_ASSERT(src4->nb[0] == sizeof(float));
|
|
132
274
|
GGML_ASSERT(src5->nb[0] == sizeof(float));
|
|
133
|
-
|
|
134
|
-
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
|
|
135
|
-
// required for per-sequence offsets for states
|
|
136
|
-
GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float));
|
|
137
|
-
// required to get correct offset for state destination (i.e. src1->nb[3])
|
|
138
|
-
GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float));
|
|
275
|
+
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
|
|
139
276
|
|
|
140
277
|
const float * src0_d = (const float *) src0->data;
|
|
141
278
|
const float * src1_d = (const float *) src1->data;
|
|
@@ -143,13 +280,16 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
143
280
|
const float * src3_d = (const float *) src3->data;
|
|
144
281
|
const float * src4_d = (const float *) src4->data;
|
|
145
282
|
const float * src5_d = (const float *) src5->data;
|
|
283
|
+
const int32_t * src6_d = (const int32_t *) src6->data;
|
|
146
284
|
float * dst_d = (float *) dst->data;
|
|
147
285
|
cudaStream_t stream = ctx.stream();
|
|
148
286
|
|
|
149
287
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
288
|
+
GGML_ASSERT(src6->type == GGML_TYPE_I32);
|
|
150
289
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
151
290
|
|
|
152
|
-
ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d,
|
|
153
|
-
|
|
154
|
-
|
|
291
|
+
ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d,
|
|
292
|
+
src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
|
|
293
|
+
src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
|
|
294
|
+
s_off, nc, nr, nh, ng, n_t, n_s, stream);
|
|
155
295
|
}
|
|
@@ -1,25 +1,9 @@
|
|
|
1
1
|
#include "sumrows.cuh"
|
|
2
2
|
|
|
3
|
-
static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
|
|
4
|
-
const int row = blockIdx.x;
|
|
5
|
-
const int col = threadIdx.x;
|
|
6
|
-
|
|
7
|
-
float sum = 0.0f;
|
|
8
|
-
for (int i = col; i < ncols; i += blockDim.x) {
|
|
9
|
-
sum += x[row * ncols + i];
|
|
10
|
-
}
|
|
11
|
-
|
|
12
|
-
sum = warp_reduce_sum(sum);
|
|
13
|
-
|
|
14
|
-
if (col == 0) {
|
|
15
|
-
dst[row] = sum;
|
|
16
|
-
}
|
|
17
|
-
}
|
|
18
|
-
|
|
19
3
|
// Enqueue reduce_rows_f32</*norm=*/false> on `stream` with one block of WARP_SIZE
// threads per row (nrows blocks in total).
//
// x      - input, forwarded to the kernel together with the row length ncols
// dst    - output, forwarded to the kernel
// NOTE(review): presumably dst receives one sum per row - confirm against reduce_rows_f32.
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const dim3 grid(nrows, 1, 1);
    const dim3 block(WARP_SIZE, 1, 1);

    reduce_rows_f32</*norm=*/false><<<grid, block, 0, stream>>>(x, dst, ncols);
}
|
|
24
8
|
|
|
25
9
|
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
@@ -35,5 +19,8 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
35
19
|
const int64_t ncols = src0->ne[0];
|
|
36
20
|
const int64_t nrows = ggml_nrows(src0);
|
|
37
21
|
|
|
38
|
-
|
|
22
|
+
const dim3 block_dims(WARP_SIZE, 1, 1);
|
|
23
|
+
const dim3 block_nums(nrows, 1, 1);
|
|
24
|
+
|
|
25
|
+
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
|
|
39
26
|
}
|