@novastera-oss/llamarn 0.2.7 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +8 -3
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +56 -22
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +1 -2
- package/cpp/llama.cpp/README.md +4 -5
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +24 -0
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +3 -0
- package/cpp/llama.cpp/common/common.h +5 -0
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
- package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
- package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
- package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
- package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
- package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
- package/cpp/llama.cpp/include/llama.h +8 -43
- package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
- package/cpp/llama.cpp/src/llama-arch.h +36 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
- package/cpp/llama.cpp/src/llama-batch.h +105 -70
- package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +101 -107
- package/cpp/llama.cpp/src/llama-context.h +13 -13
- package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
- package/cpp/llama.cpp/src/llama-graph.h +78 -79
- package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
- package/cpp/llama.cpp/src/llama-hparams.h +11 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
- package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +21 -22
- package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
- package/cpp/llama.cpp/src/llama-model.h +40 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
- package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
- package/cpp/llama.cpp/src/llama-vocab.h +42 -0
- package/cpp/rn-utils.h +3 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +5 -0
- package/ios/include/llama.h +8 -43
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -19,6 +19,10 @@ if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
|
|
19
19
|
add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
|
20
20
|
message(STATUS "Enabling bfloat16 glslc support")
|
|
21
21
|
endif()
|
|
22
|
+
if (GGML_VULKAN_SHADER_DEBUG_INFO)
|
|
23
|
+
add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO)
|
|
24
|
+
message(STATUS "Enabling shader debug info")
|
|
25
|
+
endif()
|
|
22
26
|
|
|
23
27
|
set(TARGET vulkan-shaders-gen)
|
|
24
28
|
add_executable(${TARGET} vulkan-shaders-gen.cpp)
|
|
@@ -6,17 +6,25 @@ spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bi
|
|
|
6
6
|
#endif // RTE16
|
|
7
7
|
|
|
8
8
|
#include "types.comp"
|
|
9
|
-
#include "generic_unary_head.comp"
|
|
10
9
|
|
|
11
|
-
#if defined(
|
|
12
|
-
|
|
13
|
-
|
|
10
|
+
#if defined(SET_ROWS) && QUANT_K == 1
|
|
11
|
+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
|
12
|
+
const uint BLOCK_SIZE = 512;
|
|
14
13
|
#else
|
|
15
|
-
layout(local_size_x =
|
|
14
|
+
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
|
|
15
|
+
const uint BLOCK_SIZE = 32;
|
|
16
16
|
#endif
|
|
17
17
|
|
|
18
18
|
layout (binding = 0) readonly buffer S {float data_s[];};
|
|
19
|
+
|
|
20
|
+
#if defined(SET_ROWS)
|
|
21
|
+
#include "generic_binary_head.comp"
|
|
22
|
+
layout (binding = 1) readonly buffer C {uvec2 data_i[];};
|
|
23
|
+
layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
|
|
24
|
+
#else
|
|
25
|
+
#include "generic_unary_head.comp"
|
|
19
26
|
layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
|
|
27
|
+
#endif
|
|
20
28
|
|
|
21
29
|
#if defined(DATA_A_Q4_0)
|
|
22
30
|
void quantize(uint dst_idx, uint src_idx)
|
|
@@ -221,15 +229,56 @@ void quantize(uint dst_idx, uint src_idx)
|
|
|
221
229
|
}
|
|
222
230
|
#endif
|
|
223
231
|
|
|
232
|
+
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
|
233
|
+
void quantize(uint dst_idx, uint src_idx)
|
|
234
|
+
{
|
|
235
|
+
data_q[dst_idx] = A_TYPE(data_s[src_idx]);
|
|
236
|
+
}
|
|
237
|
+
#endif
|
|
238
|
+
|
|
239
|
+
#if defined(DATA_A_BF16)
|
|
240
|
+
void quantize(uint dst_idx, uint src_idx)
|
|
241
|
+
{
|
|
242
|
+
data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx]));
|
|
243
|
+
}
|
|
244
|
+
#endif
|
|
245
|
+
|
|
246
|
+
#if defined(SET_ROWS)
|
|
247
|
+
|
|
224
248
|
void main() {
|
|
225
249
|
#ifdef NEEDS_INIT_IQ_SHMEM
|
|
226
250
|
init_iq_shmem(gl_WorkGroupSize);
|
|
227
|
-
|
|
251
|
+
#endif
|
|
252
|
+
|
|
253
|
+
const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K;
|
|
254
|
+
|
|
255
|
+
if (idx >= p.ne) {
|
|
228
256
|
return;
|
|
229
257
|
}
|
|
258
|
+
|
|
259
|
+
uint i00, i01, i02, i03;
|
|
260
|
+
get_indices(idx, i00, i01, i02, i03);
|
|
261
|
+
|
|
262
|
+
uint i12 = fastmod(i03, p.ne12);
|
|
263
|
+
uint i11 = fastmod(i02, p.ne11);
|
|
264
|
+
uint i10 = i01;
|
|
265
|
+
|
|
266
|
+
uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x;
|
|
267
|
+
|
|
268
|
+
uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();
|
|
269
|
+
uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();
|
|
270
|
+
|
|
271
|
+
quantize(dst_idx, src0_idx);
|
|
272
|
+
}
|
|
273
|
+
|
|
274
|
+
#else
|
|
275
|
+
|
|
276
|
+
void main() {
|
|
277
|
+
#ifdef NEEDS_INIT_IQ_SHMEM
|
|
278
|
+
init_iq_shmem(gl_WorkGroupSize);
|
|
230
279
|
#endif
|
|
231
280
|
|
|
232
|
-
const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
|
|
281
|
+
const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K;
|
|
233
282
|
|
|
234
283
|
if (idx >= p.ne) {
|
|
235
284
|
return;
|
|
@@ -240,3 +289,5 @@ void main() {
|
|
|
240
289
|
|
|
241
290
|
quantize(dst_idx, src_idx);
|
|
242
291
|
}
|
|
292
|
+
|
|
293
|
+
#endif
|
|
@@ -11,7 +11,8 @@
|
|
|
11
11
|
#include "types.comp"
|
|
12
12
|
#include "flash_attn_base.comp"
|
|
13
13
|
|
|
14
|
-
const uint32_t
|
|
14
|
+
const uint32_t HSK_per_thread = HSK / D_split;
|
|
15
|
+
const uint32_t HSV_per_thread = HSV / D_split;
|
|
15
16
|
|
|
16
17
|
const uint32_t cols_per_iter = WorkGroupSize / D_split;
|
|
17
18
|
const uint32_t cols_per_thread = Bc / cols_per_iter;
|
|
@@ -29,7 +30,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
|
|
|
29
30
|
// Rows index by Q's dimension 2, and the first N rows are valid.
|
|
30
31
|
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
|
31
32
|
{
|
|
32
|
-
uint32_t offset = (iq2 + r) *
|
|
33
|
+
uint32_t offset = (iq2 + r) * HSV + c;
|
|
33
34
|
data_o[o_offset + offset] = D_TYPE(elem);
|
|
34
35
|
return elem;
|
|
35
36
|
}
|
|
@@ -38,7 +39,7 @@ shared FLOAT_TYPE tmpsh[WorkGroupSize];
|
|
|
38
39
|
shared vec4 tmpshv4[WorkGroupSize];
|
|
39
40
|
|
|
40
41
|
shared float masksh[Bc][Br];
|
|
41
|
-
shared vec4 Qf[Br][
|
|
42
|
+
shared vec4 Qf[Br][HSK / 4];
|
|
42
43
|
|
|
43
44
|
void main() {
|
|
44
45
|
#ifdef NEEDS_INIT_IQ_SHMEM
|
|
@@ -53,18 +54,18 @@ void main() {
|
|
|
53
54
|
|
|
54
55
|
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
|
55
56
|
|
|
56
|
-
[[unroll]] for (uint32_t idx = 0; idx < Br *
|
|
57
|
-
uint32_t d = (idx + tid) % (
|
|
58
|
-
uint32_t r = (idx + tid) / (
|
|
59
|
-
if (r < Br && d <
|
|
57
|
+
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
|
|
58
|
+
uint32_t d = (idx + tid) % (HSK / 4);
|
|
59
|
+
uint32_t r = (idx + tid) / (HSK / 4);
|
|
60
|
+
if (r < Br && d < HSK / 4 &&
|
|
60
61
|
i * Br + r < N) {
|
|
61
62
|
Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
|
|
62
63
|
}
|
|
63
64
|
}
|
|
64
65
|
barrier();
|
|
65
66
|
|
|
66
|
-
vec4 Of[Br][
|
|
67
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
67
|
+
vec4 Of[Br][HSV_per_thread / 4];
|
|
68
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
68
69
|
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
|
69
70
|
Of[r][d] = vec4(0.0);
|
|
70
71
|
}
|
|
@@ -99,6 +100,10 @@ void main() {
|
|
|
99
100
|
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
|
|
100
101
|
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
|
|
101
102
|
#endif
|
|
103
|
+
uint32_t m_offset = 0;
|
|
104
|
+
if (p.nem2 != 1 || p.nem3 != 1) {
|
|
105
|
+
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
|
|
106
|
+
}
|
|
102
107
|
|
|
103
108
|
[[dont_unroll]]
|
|
104
109
|
for (uint32_t j = start_j; j < end_j; ++j) {
|
|
@@ -112,7 +117,7 @@ void main() {
|
|
|
112
117
|
|
|
113
118
|
|
|
114
119
|
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
|
115
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
120
|
+
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
|
|
116
121
|
#if BLOCK_SIZE > 1
|
|
117
122
|
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
|
118
123
|
uint ib = coord / BLOCK_SIZE;
|
|
@@ -144,13 +149,13 @@ void main() {
|
|
|
144
149
|
}
|
|
145
150
|
}
|
|
146
151
|
|
|
147
|
-
if (p.
|
|
152
|
+
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
|
148
153
|
|
|
149
154
|
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
|
150
155
|
uint32_t c = (idx + tid) % Bc;
|
|
151
156
|
uint32_t r = (idx + tid) / Bc;
|
|
152
157
|
if (idx + tid < Bc * Br) {
|
|
153
|
-
masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
|
|
158
|
+
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
|
154
159
|
}
|
|
155
160
|
}
|
|
156
161
|
barrier();
|
|
@@ -191,14 +196,14 @@ void main() {
|
|
|
191
196
|
Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
|
|
192
197
|
}
|
|
193
198
|
|
|
194
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
199
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
195
200
|
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
|
196
201
|
Of[r][d] = eMf[r] * Of[r][d];
|
|
197
202
|
}
|
|
198
203
|
}
|
|
199
204
|
|
|
200
205
|
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
|
201
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
206
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
202
207
|
#if BLOCK_SIZE > 1
|
|
203
208
|
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
|
204
209
|
uint ib = coord / BLOCK_SIZE;
|
|
@@ -255,7 +260,7 @@ void main() {
|
|
|
255
260
|
Lf[r] = tmpsh[d_tid];
|
|
256
261
|
barrier();
|
|
257
262
|
|
|
258
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
263
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
259
264
|
|
|
260
265
|
Of[r][d] = eMf * Of[r][d];
|
|
261
266
|
tmpshv4[tid] = Of[r][d];
|
|
@@ -277,11 +282,11 @@ void main() {
|
|
|
277
282
|
// If there is split_k, then the split_k resolve shader does the final
|
|
278
283
|
// division by L. Store the intermediate O value and per-row m and L values.
|
|
279
284
|
if (p.k_num > 1) {
|
|
280
|
-
uint32_t o_offset =
|
|
285
|
+
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
|
|
281
286
|
|
|
282
287
|
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
|
283
288
|
if (r < N) {
|
|
284
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
289
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
285
290
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
|
286
291
|
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
|
|
287
292
|
}
|
|
@@ -289,7 +294,7 @@ void main() {
|
|
|
289
294
|
}
|
|
290
295
|
}
|
|
291
296
|
|
|
292
|
-
o_offset =
|
|
297
|
+
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
|
|
293
298
|
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
|
294
299
|
if (r < N) {
|
|
295
300
|
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
|
|
@@ -305,18 +310,18 @@ void main() {
|
|
|
305
310
|
Lfrcp[r] = 1.0 / Lf[r];
|
|
306
311
|
}
|
|
307
312
|
|
|
308
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
313
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
309
314
|
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
|
310
315
|
Of[r][d] *= Lfrcp[r];
|
|
311
316
|
}
|
|
312
317
|
}
|
|
313
318
|
|
|
314
|
-
uint32_t o_offset = iq3*p.ne2*p.ne1;
|
|
319
|
+
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
|
|
315
320
|
|
|
316
321
|
if (p.gqa_ratio > 1) {
|
|
317
322
|
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
|
318
323
|
if (r < N) {
|
|
319
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
324
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
320
325
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
|
321
326
|
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
|
|
322
327
|
}
|
|
@@ -326,9 +331,9 @@ void main() {
|
|
|
326
331
|
} else {
|
|
327
332
|
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
|
328
333
|
if (i * Br + r < N) {
|
|
329
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
334
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
330
335
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
|
331
|
-
data_o[o_offset + iq2 *
|
|
336
|
+
data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
|
|
332
337
|
}
|
|
333
338
|
}
|
|
334
339
|
}
|
|
@@ -4,10 +4,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
|
|
4
4
|
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
|
|
5
5
|
layout (constant_id = 1) const uint32_t Br = 1;
|
|
6
6
|
layout (constant_id = 2) const uint32_t Bc = 32;
|
|
7
|
-
layout (constant_id = 3) const uint32_t
|
|
8
|
-
layout (constant_id = 4) const uint32_t
|
|
9
|
-
layout (constant_id = 5) const uint32_t
|
|
10
|
-
|
|
7
|
+
layout (constant_id = 3) const uint32_t HSK = 32;
|
|
8
|
+
layout (constant_id = 4) const uint32_t HSV = 32;
|
|
9
|
+
layout (constant_id = 5) const uint32_t Clamp = 0;
|
|
10
|
+
layout (constant_id = 6) const uint32_t D_split = 16;
|
|
11
11
|
|
|
12
12
|
layout (push_constant) uniform parameter {
|
|
13
13
|
uint32_t N;
|
|
@@ -24,6 +24,8 @@ layout (push_constant) uniform parameter {
|
|
|
24
24
|
uint32_t nev2;
|
|
25
25
|
uint32_t nev3;
|
|
26
26
|
uint32_t nem1;
|
|
27
|
+
uint32_t nem2;
|
|
28
|
+
uint32_t nem3;
|
|
27
29
|
|
|
28
30
|
uint32_t nb01;
|
|
29
31
|
uint32_t nb02;
|
|
@@ -34,14 +36,12 @@ layout (push_constant) uniform parameter {
|
|
|
34
36
|
uint32_t nb21;
|
|
35
37
|
uint32_t nb22;
|
|
36
38
|
uint32_t nb23;
|
|
37
|
-
uint32_t nb31;
|
|
38
39
|
|
|
39
40
|
float scale;
|
|
40
41
|
float max_bias;
|
|
41
42
|
float logit_softcap;
|
|
42
43
|
|
|
43
|
-
uint32_t
|
|
44
|
-
uint32_t n_head_log2;
|
|
44
|
+
uint32_t mask_n_head_log2;
|
|
45
45
|
float m0;
|
|
46
46
|
float m1;
|
|
47
47
|
|
|
@@ -50,6 +50,9 @@ layout (push_constant) uniform parameter {
|
|
|
50
50
|
uint32_t k_num;
|
|
51
51
|
} p;
|
|
52
52
|
|
|
53
|
+
#define MASK_ENABLE_BIT (1<<16)
|
|
54
|
+
#define N_LOG2_MASK 0xFFFF
|
|
55
|
+
|
|
53
56
|
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
|
|
54
57
|
|
|
55
58
|
#if defined(A_TYPE_PACKED16)
|
|
@@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
|
|
|
100
103
|
{
|
|
101
104
|
const uint32_t h = iq2 + (r % p.gqa_ratio);
|
|
102
105
|
|
|
103
|
-
|
|
104
|
-
|
|
106
|
+
uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
|
|
107
|
+
|
|
108
|
+
const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
|
|
109
|
+
const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
|
|
105
110
|
|
|
106
111
|
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
|
|
107
112
|
}
|
|
@@ -13,7 +13,9 @@
|
|
|
13
13
|
#include "types.comp"
|
|
14
14
|
#include "flash_attn_base.comp"
|
|
15
15
|
|
|
16
|
-
const uint32_t
|
|
16
|
+
const uint32_t HSK_per_thread = HSK / D_split;
|
|
17
|
+
const uint32_t HSV_per_thread = HSV / D_split;
|
|
18
|
+
|
|
17
19
|
const uint32_t row_split = 4;
|
|
18
20
|
const uint32_t rows_per_thread = Br / row_split;
|
|
19
21
|
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
|
|
@@ -32,7 +34,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
|
|
|
32
34
|
// Rows index by Q's dimension 2, and the first N rows are valid.
|
|
33
35
|
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
|
34
36
|
{
|
|
35
|
-
uint32_t offset = (iq2 + r) *
|
|
37
|
+
uint32_t offset = (iq2 + r) * HSV + c;
|
|
36
38
|
data_o[o_offset + offset] = D_TYPE(elem);
|
|
37
39
|
return elem;
|
|
38
40
|
}
|
|
@@ -44,14 +46,14 @@ const uint32_t MatBc = 16;
|
|
|
44
46
|
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
|
|
45
47
|
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
|
|
46
48
|
|
|
47
|
-
const uint32_t qstride =
|
|
49
|
+
const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4
|
|
48
50
|
shared f16vec4 Qf[Br * qstride];
|
|
49
51
|
|
|
50
|
-
// Avoid padding for
|
|
51
|
-
const uint32_t sfshstride = (
|
|
52
|
+
// Avoid padding for hsk==256 to make it fit in 48KB shmem.
|
|
53
|
+
const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
|
|
52
54
|
shared ACC_TYPE sfsh[Bc * sfshstride];
|
|
53
55
|
|
|
54
|
-
const uint32_t kshstride =
|
|
56
|
+
const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4
|
|
55
57
|
shared f16vec4 ksh[Bc * kshstride];
|
|
56
58
|
|
|
57
59
|
shared float slope[Br];
|
|
@@ -74,18 +76,18 @@ void main() {
|
|
|
74
76
|
|
|
75
77
|
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
|
76
78
|
|
|
77
|
-
[[unroll]] for (uint32_t idx = 0; idx < Br *
|
|
78
|
-
uint32_t d = (idx + tid) % (
|
|
79
|
-
uint32_t r = (idx + tid) / (
|
|
80
|
-
if (r < Br && d <
|
|
79
|
+
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
|
|
80
|
+
uint32_t d = (idx + tid) % (HSK / 4);
|
|
81
|
+
uint32_t r = (idx + tid) / (HSK / 4);
|
|
82
|
+
if (r < Br && d < HSK / 4 &&
|
|
81
83
|
i * Br + r < N) {
|
|
82
84
|
Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
|
|
83
85
|
}
|
|
84
86
|
}
|
|
85
87
|
barrier();
|
|
86
88
|
|
|
87
|
-
ACC_TYPEV4 Of[rows_per_thread][
|
|
88
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
89
|
+
ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
|
|
90
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
89
91
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
90
92
|
Of[r][d] = ACC_TYPEV4(0.0);
|
|
91
93
|
}
|
|
@@ -123,14 +125,18 @@ void main() {
|
|
|
123
125
|
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
|
|
124
126
|
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
|
|
125
127
|
#endif
|
|
128
|
+
uint32_t m_offset = 0;
|
|
129
|
+
if (p.nem2 != 1 || p.nem3 != 1) {
|
|
130
|
+
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
|
|
131
|
+
}
|
|
126
132
|
|
|
127
133
|
[[dont_unroll]]
|
|
128
134
|
for (uint32_t j = start_j; j < end_j; ++j) {
|
|
129
135
|
|
|
130
|
-
[[unroll]] for (uint32_t idx = 0; idx < Bc *
|
|
131
|
-
uint32_t d = (idx + tid) % (
|
|
132
|
-
uint32_t c = (idx + tid) / (
|
|
133
|
-
if (c < Bc && d <
|
|
136
|
+
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
|
|
137
|
+
uint32_t d = (idx + tid) % (HSK / 4);
|
|
138
|
+
uint32_t c = (idx + tid) / (HSK / 4);
|
|
139
|
+
if (c < Bc && d < HSK / 4) {
|
|
134
140
|
#if BLOCK_SIZE > 1
|
|
135
141
|
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
|
|
136
142
|
uint ib = coord / BLOCK_SIZE;
|
|
@@ -145,14 +151,14 @@ void main() {
|
|
|
145
151
|
}
|
|
146
152
|
barrier();
|
|
147
153
|
|
|
148
|
-
// K * Q^T -> S^T: Bc x
|
|
149
|
-
// Bc split across workgroup (four subgroups), loop over
|
|
154
|
+
// K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br
|
|
155
|
+
// Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
|
|
150
156
|
// This is written transposed in order to allow for N being 8 if implementations need it
|
|
151
157
|
coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
|
|
152
158
|
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
|
|
153
159
|
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
|
|
154
160
|
|
|
155
|
-
for (uint32_t d = 0; d <
|
|
161
|
+
for (uint32_t d = 0; d < HSK / 16; ++d) {
|
|
156
162
|
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
|
|
157
163
|
|
|
158
164
|
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
|
|
@@ -176,12 +182,12 @@ void main() {
|
|
|
176
182
|
barrier();
|
|
177
183
|
}
|
|
178
184
|
|
|
179
|
-
if (p.
|
|
185
|
+
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
|
180
186
|
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
|
181
187
|
uint32_t c = (idx + tid) % Bc;
|
|
182
188
|
uint32_t r = (idx + tid) / Bc;
|
|
183
189
|
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
|
184
|
-
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
|
|
190
|
+
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
|
|
185
191
|
}
|
|
186
192
|
}
|
|
187
193
|
barrier();
|
|
@@ -202,7 +208,7 @@ void main() {
|
|
|
202
208
|
eMf[r] = exp(Moldf - Mf[r]);
|
|
203
209
|
}
|
|
204
210
|
|
|
205
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
211
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
206
212
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
207
213
|
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
|
|
208
214
|
}
|
|
@@ -217,7 +223,7 @@ void main() {
|
|
|
217
223
|
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
|
|
218
224
|
Lf[r] += Pf[r];
|
|
219
225
|
}
|
|
220
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
226
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
221
227
|
#if BLOCK_SIZE > 1
|
|
222
228
|
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
|
223
229
|
uint ib = coord / BLOCK_SIZE;
|
|
@@ -280,7 +286,7 @@ void main() {
|
|
|
280
286
|
}
|
|
281
287
|
|
|
282
288
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
283
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
289
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
284
290
|
|
|
285
291
|
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
|
|
286
292
|
tmpshv4[tid] = Of[r][d];
|
|
@@ -300,11 +306,11 @@ void main() {
|
|
|
300
306
|
// If there is split_k, then the split_k resolve shader does the final
|
|
301
307
|
// division by L. Store the intermediate O value and per-row m and L values.
|
|
302
308
|
if (p.k_num > 1) {
|
|
303
|
-
uint32_t o_offset =
|
|
309
|
+
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
|
|
304
310
|
|
|
305
311
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
306
312
|
if (tile_row(r) < N) {
|
|
307
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
313
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
308
314
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
|
309
315
|
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
|
|
310
316
|
}
|
|
@@ -312,7 +318,7 @@ void main() {
|
|
|
312
318
|
}
|
|
313
319
|
}
|
|
314
320
|
|
|
315
|
-
o_offset =
|
|
321
|
+
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
|
|
316
322
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
317
323
|
if (tile_row(r) < N) {
|
|
318
324
|
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
|
|
@@ -328,18 +334,18 @@ void main() {
|
|
|
328
334
|
Lfrcp[r] = 1.0 / Lf[r];
|
|
329
335
|
}
|
|
330
336
|
|
|
331
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
337
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
332
338
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
333
339
|
Of[r][d] *= float16_t(Lfrcp[r]);
|
|
334
340
|
}
|
|
335
341
|
}
|
|
336
342
|
|
|
337
|
-
uint32_t o_offset = iq3*p.ne2*p.ne1;
|
|
343
|
+
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
|
|
338
344
|
|
|
339
345
|
if (p.gqa_ratio > 1) {
|
|
340
346
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
341
347
|
if (tile_row(r) < N) {
|
|
342
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
348
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
343
349
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
|
344
350
|
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
|
|
345
351
|
}
|
|
@@ -349,9 +355,9 @@ void main() {
|
|
|
349
355
|
} else {
|
|
350
356
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
351
357
|
if (i * Br + tile_row(r) < N) {
|
|
352
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
358
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
353
359
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
|
354
|
-
data_o[o_offset + iq2 *
|
|
360
|
+
data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
|
|
355
361
|
}
|
|
356
362
|
}
|
|
357
363
|
}
|