@novastera-oss/llamarn 0.3.1 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +86 -3
- package/RNLlamaCpp.podspec +1 -1
- package/android/CMakeLists.txt +11 -3
- package/android/generated/jni/react/renderer/components/RNLlamaCppSpec/RNLlamaCppSpecJSI.h +49 -4
- package/android/src/main/cpp/include/llama.h +53 -114
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +2 -10
- package/cpp/PureCppImpl.cpp +71 -4
- package/cpp/SystemUtils.cpp +3 -7
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -1
- package/cpp/llama.cpp/Makefile +6 -1605
- package/cpp/llama.cpp/README.md +5 -1
- package/cpp/llama.cpp/common/arg.cpp +230 -51
- package/cpp/llama.cpp/common/chat-parser.cpp +9 -1
- package/cpp/llama.cpp/common/chat.cpp +539 -8
- package/cpp/llama.cpp/common/chat.h +8 -1
- package/cpp/llama.cpp/common/common.cpp +60 -15
- package/cpp/llama.cpp/common/common.h +64 -15
- package/cpp/llama.cpp/common/speculative.cpp +135 -54
- package/cpp/llama.cpp/common/speculative.h +8 -1
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1216 -109
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +19 -6
- package/cpp/llama.cpp/convert_lora_to_gguf.py +1 -1
- package/cpp/llama.cpp/flake.nix +0 -5
- package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -3
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +71 -70
- package/cpp/llama.cpp/ggml/include/ggml-opt.h +25 -6
- package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +90 -3
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +13 -1
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +10 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +113 -17
- package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +701 -585
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +13 -3
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +52 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +274 -91
- package/cpp/llama.cpp/ggml/src/ggml-common.h +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +90 -569
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +55 -341
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +371 -298
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +33 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +26 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +428 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +458 -46
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
- package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +122 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +9 -11
- package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cu +58 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +275 -170
- package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +103 -65
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +171 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +33 -7
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +2 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +3 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +83 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +116 -57
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +45 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +56 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +61 -39
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +70 -49
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +70 -21
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +162 -50
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +5 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +208 -97
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +46 -35
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +56 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +95 -51
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +427 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +204 -57
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +252 -168
- package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
- package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +10 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +192 -19
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cu +67 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +1 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cu +34 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +16 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +153 -71
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +6 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +21 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +75 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -25
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -1
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +31 -20
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +342 -131
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +464 -134
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1108 -176
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +343 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +343 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +346 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +41 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +97 -41
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +110 -16
- package/cpp/llama.cpp/ggml/src/ggml-quants.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +22 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +0 -212
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +213 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +117 -238
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quantize.hpp +133 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1666 -633
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +107 -43
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +21 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +16 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +44 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +44 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +26 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -17
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +37 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +109 -55
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +71 -41
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -11
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +55 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +75 -20
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +807 -412
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +72 -22
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +1794 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +846 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +204 -50
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +187 -2
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +11 -2
- package/cpp/llama.cpp/gguf-py/gguf/quants.py +53 -4
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +67 -63
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +7 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +120 -16
- package/cpp/llama.cpp/gguf-py/gguf/utility.py +5 -1
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +284 -1
- package/cpp/llama.cpp/gguf-py/tests/test_quants.py +14 -5
- package/cpp/llama.cpp/include/llama.h +53 -114
- package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +171 -0
- package/cpp/llama.cpp/models/templates/README.md +2 -1
- package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +59 -0
- package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +331 -0
- package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +105 -0
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -1
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +0 -6
- package/cpp/llama.cpp/requirements/requirements-pydantic.txt +1 -1
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-adapter.cpp +68 -4
- package/cpp/llama.cpp/src/llama-adapter.h +3 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +192 -2
- package/cpp/llama.cpp/src/llama-arch.h +18 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +2 -2
- package/cpp/llama.cpp/src/llama-chat.cpp +47 -6
- package/cpp/llama.cpp/src/llama-chat.h +3 -0
- package/cpp/llama.cpp/src/llama-context.cpp +61 -252
- package/cpp/llama.cpp/src/llama-context.h +10 -15
- package/cpp/llama.cpp/src/llama-cparams.h +0 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +180 -85
- package/cpp/llama.cpp/src/llama-graph.h +90 -51
- package/cpp/llama.cpp/src/llama-hparams.cpp +34 -3
- package/cpp/llama.cpp/src/llama-hparams.h +21 -6
- package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +79 -56
- package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +30 -28
- package/cpp/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +240 -632
- package/cpp/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +39 -74
- package/cpp/llama.cpp/src/llama-kv-cells.h +21 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +41 -35
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +26 -29
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +13 -9
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +10 -14
- package/cpp/llama.cpp/src/llama-memory.h +13 -10
- package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
- package/cpp/llama.cpp/src/llama-model-loader.h +3 -2
- package/cpp/llama.cpp/src/llama-model.cpp +1959 -419
- package/cpp/llama.cpp/src/llama-model.h +28 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +40 -4
- package/cpp/llama.cpp/src/llama-vocab.cpp +51 -2
- package/cpp/llama.cpp/src/llama-vocab.h +1 -0
- package/cpp/llama.cpp/vendor/minja/chat-template.hpp +16 -7
- package/cpp/llama.cpp/vendor/minja/minja.hpp +47 -12
- package/cpp/rn-completion.cpp +3 -27
- package/ios/generated/RNLlamaCppSpec/RNLlamaCppSpec.h +30 -0
- package/ios/generated/RNLlamaCppSpecJSI.h +49 -4
- package/ios/include/chat.h +8 -1
- package/ios/include/common/minja/chat-template.hpp +16 -7
- package/ios/include/common/minja/minja.hpp +47 -12
- package/ios/include/common.h +64 -15
- package/ios/include/llama.h +53 -114
- package/ios/include/speculative.h +8 -1
- package/ios/libs/llama.xcframework/Info.plist +18 -18
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5557 -5267
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5520 -5238
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4242 -4016
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5556 -5267
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5553 -5303
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5515 -5274
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4238 -4044
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/lib/module/NativeRNLlamaCpp.js.map +1 -1
- package/lib/typescript/src/NativeRNLlamaCpp.d.ts +5 -0
- package/lib/typescript/src/NativeRNLlamaCpp.d.ts.map +1 -1
- package/package.json +1 -2
- package/src/NativeRNLlamaCpp.ts +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -56
package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +21 -0

@@ -305,6 +305,27 @@ void main() {
         return;
     }
 
+    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
+        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+            float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (sink > Mf[r]) {
+                ms = exp(Mf[r] - sink);
+
+                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                    Of[r][d] *= ms;
+                }
+            } else {
+                vs = exp(sink - Mf[r]);
+            }
+
+            Lf[r] = Lf[r]*ms + vs;
+        }
+    }
+
     float Lfrcp[Br];
     [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
         Lfrcp[r] = 1.0 / Lf[r];
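
The sink epilogue added to flash_attn.comp above (and mirrored in the cm1/cm2 variants further down) folds one extra "sink" logit into the online-softmax running state just before normalization. Reading the shader, with M_r the running row maximum, L_r the running row sum, O_r the unnormalized row accumulator and s the sink value, the ms/vs rescaling is equivalent to

$$
M_r' = \max(M_r, s), \qquad
O_r' = O_r\, e^{M_r - M_r'}, \qquad
L_r' = L_r\, e^{M_r - M_r'} + e^{\,s - M_r'},
$$

so the final output O_r'/L_r' behaves as if one additional key with logit s and an all-zero value row had taken part in the softmax.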

package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +16 -1

@@ -9,6 +9,10 @@ layout (constant_id = 4) const uint32_t HSV = 32;
 layout (constant_id = 5) const uint32_t Clamp = 0;
 layout (constant_id = 6) const uint32_t D_split = 16;
 
+// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
+const uint32_t HSK_pad = (HSK + 15) & ~15;
+const uint32_t HSV_pad = (HSV + 15) & ~15;
+
 layout (push_constant) uniform parameter {
     uint32_t N;
     uint32_t KV;
@@ -50,10 +54,13 @@ layout (push_constant) uniform parameter {
     uint32_t k_num;
 } p;
 
+#define SINK_ENABLE_BIT (1<<24)
 #define MASK_ENABLE_BIT (1<<16)
 #define N_LOG2_MASK 0xFFFF
 
-layout (binding = 4)
+layout (binding = 4) readonly buffer S {float data_s[];};
+
+layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
 
 #if defined(A_TYPE_PACKED16)
 #define BINDING_IDX_K 0
@@ -111,6 +118,14 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
     return ACC_TYPE(pow(base, ACC_TYPE(exph)));
 }
 
+// Load the sink value, indexed by Q's dimension 2.
+ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
+{
+    const uint32_t h = iq2 + (r % p.gqa_ratio);
+
+    return ACC_TYPE(data_s[h]);
+}
+
 uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
          iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
          q_stride, k_stride, v_stride, m_stride;
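
The new HSK_pad / HSV_pad constants in flash_attn_base.comp round the key and value head sizes up to the next multiple of 16 so the coopmat1/coopmat2 paths can always tile in 16-wide chunks: for example HSK = 72 gives HSK_pad = (72 + 15) & ~15 = 80, while head sizes already divisible by 16 (64, 128, 256) are left unchanged.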

package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +44 -8

@@ -46,14 +46,14 @@ const uint32_t MatBc = 16;
 shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
 shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
 
-const uint32_t qstride =
+const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
 shared f16vec4 Qf[Br * qstride];
 
 // Avoid padding for hsk==256 to make it fit in 48KB shmem.
 const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
 shared ACC_TYPE sfsh[Bc * sfshstride];
 
-const uint32_t kshstride =
+const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4
 shared f16vec4 ksh[Bc * kshstride];
 
 shared float slope[Br];
@@ -74,6 +74,21 @@ void main() {
 
 #define tile_row(r) (row_tid * rows_per_thread + (r))
 
+    // Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK).
+    if ((HSK % 16) != 0) {
+        [[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) {
+            if (i + tid < Br * qstride) {
+                Qf[i + tid] = f16vec4(0);
+            }
+        }
+        [[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
+            if (i + tid < Bc * kshstride) {
+                ksh[i + tid] = f16vec4(0);
+            }
+        }
+        barrier();
+    }
+
     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
 
     [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
@@ -151,14 +166,14 @@
     }
     barrier();
 
-    // K * Q^T -> S^T: Bc x
+    // K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br
     // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
     // This is written transposed in order to allow for N being 8 if implementations need it
     coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
     coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
     coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
 
-    for (uint32_t d = 0; d <
+    for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
         coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
 
         uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
@@ -210,7 +225,7 @@
 
         [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
             [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-                Of[r][d] =
+                Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
             }
         }
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
@@ -233,7 +248,7 @@
                 vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
 #endif
                 [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-                    Of[r][d] +=
+                    Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf);
                 }
             }
         }
@@ -288,7 +303,7 @@
     [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
         [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
 
-            Of[r][d] =
+            Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
             tmpshv4[tid] = Of[r][d];
 
             barrier();
@@ -329,6 +344,27 @@
         return;
     }
 
+    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
+        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+            float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (sink > Mf[r]) {
+                ms = exp(Mf[r] - sink);
+
+                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                    Of[r][d] *= ACC_TYPE(ms);
+                }
+            } else {
+                vs = exp(sink - Mf[r]);
+            }
+
+            Lf[r] = Lf[r]*ms + vs;
+        }
+    }
+
     float Lfrcp[rows_per_thread];
     [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
         Lfrcp[r] = 1.0 / Lf[r];
@@ -336,7 +372,7 @@
 
     [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            Of[r][d] *=
+            Of[r][d] *= ACC_TYPE(Lfrcp[r]);
         }
     }
 

package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +44 -16

@@ -104,16 +104,16 @@ void main() {
     tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
     tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
 
-    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br,
-    coopmat<float16_t, gl_ScopeWorkgroup, Br,
+    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;
+    coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;
 
     uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
-    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0,
+    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));
 
-    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br,
+    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
     Qf16 *= float16_t(p.scale);
 
-    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br,
+    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
 
     coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
 
@@ -140,10 +140,10 @@
 
         coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
 
-        coopmat<float16_t, gl_ScopeWorkgroup,
+        coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
 
         uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
-        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0,
+        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
         S = coopMatMulAdd(Qf16, K_T, S);
 
         if (p.logit_softcap != 0.0f) {
@@ -208,31 +208,31 @@
         rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
         rowsum = coopMatMulAdd(P_A, One, rowsum);
 
-        coopmat<float16_t, gl_ScopeWorkgroup, Bc,
+        coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
         uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
-        coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0,
+        coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC);
 
         L = eM*L + rowsum;
 
         // This is the "diagonal" matrix in the paper, but since we do componentwise
         // multiply rather than matrix multiply it has the diagonal element smeared
         // across the row
-        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br,
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> eMdiag;
 
         // resize eM by using smear/reduce
         coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
 
         // multiply with fp16 accumulation, then add to O.
-        coopmat<float16_t, gl_ScopeWorkgroup, Br,
+        coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
         PV = coopMatMulAdd(P_A, V, PV);
 
-        O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br,
+        O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(PV);
     }
 
     // If there is split_k, then the split_k resolve shader does the final
     // division by L. Store the intermediate O value and per-row m and L values.
     if (p.k_num > 1) {
-        coopmat<D_TYPE, gl_ScopeWorkgroup, Br,
+        coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
 
         uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
         coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
@@ -243,11 +243,39 @@
         return;
     }
 
-    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br,
+    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Ldiag;
 
     // resize L by using smear/reduce
     coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
 
+    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> S;
+        coopMatPerElementNV(S, S, perElemOpGetSink, iq2);
+
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Mr;
+
+        // resize M by using smear/reduce
+        coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);
+
+        // O, Ldiag, Mr all have the same type so all element locations match
+        [[unroll]] for (uint32_t i = 0; i < Ldiag.length(); ++i) {
+            ACC_TYPE sink = S[i];
+
+            ACC_TYPE ms = ACC_TYPE(1.0f);
+            ACC_TYPE vs = ACC_TYPE(1.0f);
+
+            if (sink > Mr[i]) {
+                ms = exp(Mr[i] - sink);
+
+                O[i] *= ms;
+            } else {
+                vs = exp(sink - Mr[i]);
+            }
+
+            Ldiag[i] = Ldiag[i]*ms + vs;
+        }
+    }
+
     [[unroll]]
     for (int k = 0; k < Ldiag.length(); ++k) {
         Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
@@ -257,7 +285,7 @@
 
     uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
 
-    coopmat<D_TYPE, gl_ScopeWorkgroup, Br,
+    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
     if (p.gqa_ratio > 1) {
         coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
     } else {
@@ -267,6 +295,6 @@
         // permute dimensions
         tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
 
-        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0,
+        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV_pad), tensorViewPermute);
     }
 }

package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +26 -1

@@ -7,13 +7,15 @@ layout(constant_id = 0) const uint BLOCK_SIZE = 32;
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {float data_a[];};
-layout (binding = 1)
+layout (binding = 1) readonly buffer B {float data_s[];};
+layout (binding = 2) writeonly buffer D {float data_d[];};
 
 layout (push_constant) uniform parameter {
     uint D;
     uint N;
     uint ne3;
     uint k_num;
+    uint sinks;
 } p;
 
 shared float tmpsh[BLOCK_SIZE];
@@ -73,6 +75,22 @@ void main() {
     }
     L = tmpsh[0];
 
+    float sink;
+    if (p.sinks != 0) {
+        sink = data_s[n];
+
+        float ms = 1.0f;
+        float vs = 1.0f;
+
+        if (sink > m_max) {
+            ms = exp(m_max - sink);
+        } else {
+            vs = exp(sink - m_max);
+        }
+
+        L = L*ms + vs;
+    }
+
     L = 1.0 / L;
 
     // D dimension is split across workgroups in the y dimension
@@ -85,6 +103,13 @@
         float m = data_a[m_offset + k * lm_stride];
         O += exp(m - m_max) * data_a[o_offset];
     }
+    if (p.sinks != 0) {
+        if (sink > m_max) {
+            float ms = 1.0f;
+            ms = exp(m_max - sink);
+            O *= ms;
+        }
+    }
     O *= L;
     data_d[iq3 * D * N + D * n + d] = O;
 }
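
In flash_attn_split_k_reduce.comp, each split k contributes a partial accumulator O_k, a partial sum L_k and a partial row maximum m_k. The added p.sinks path applies the same correction as the main shaders, but only once per row, after the per-split maxima are merged; the value it resolves is equivalent to

$$
M = \max\Big(\max_k m_k,\; s\Big), \qquad
O = \sum_k e^{m_k - M}\, O_k, \qquad
L = \sum_k e^{m_k - M}\, L_k + e^{\,s - M}, \qquad
\mathrm{out} = O / L .
$$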

package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -17

@@ -2,6 +2,7 @@
 #extension GL_EXT_control_flow_attributes : require
 
 #include "rte.comp"
+#include "utils.comp"
 
 layout (push_constant) uniform parameter
 {
@@ -28,25 +29,9 @@ uint get_aoffset() { return p.misalign_offsets >> 16; }
 uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; }
 uint get_doffset() { return p.misalign_offsets & 0xFF; }
 
-// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1
-uint fastmod(uint a, uint b) {
-    if ((b & (b-1)) == 0) {
-        return a & (b-1);
-    }
-    return a % b;
-}
-
-uint fastdiv(uint a, uint b) {
-    return (a < b) ? 0 : (a / b);
-}
 
 void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) {
-
-    const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
-    i02 = fastdiv((idx - i03_offset), (p.ne01*p.ne00));
-    const uint i02_offset = i02*p.ne01*p.ne00;
-    i01 = (idx - i03_offset - i02_offset) / p.ne00;
-    i00 = idx - i03_offset - i02_offset - i01*p.ne00;
+    get_indices(idx, i00, i01, i02, i03, p.ne00, p.ne01, p.ne02, p.ne03);
 }
 
 uint src0_idx(uint i00, uint i01, uint i02, uint i03) {

package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +37 -1

@@ -1,6 +1,10 @@
 #extension GL_EXT_control_flow_attributes : enable
 #extension GL_EXT_shader_16bit_storage : require
 #extension GL_EXT_shader_8bit_storage : require
+#if USE_SUBGROUP_ADD
+#extension GL_KHR_shader_subgroup_basic : require
+#extension GL_KHR_shader_subgroup_arithmetic : require
+#endif
 
 #ifdef MUL_MAT_ID
 #define EXPERT_COUNT 8
@@ -90,7 +94,38 @@ layout (constant_id = 2) const uint NUM_COLS = 1;
 
 shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];
 
-void reduce_result(
+void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
+    // subgroupAdd is probably faster on devices that support it,
+    // particularly when the workgroup has more than one subgroup
+#if USE_SUBGROUP_ADD
+    // sum up partial sums within a subgroup
+    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            temp[j][n] = subgroupAdd(temp[j][n]);
+        }
+    }
+
+    // Go through shared memory to sum partials across subgroups
+    if (gl_SubgroupInvocationID == 0) {
+        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+                tmpsh[j][n][gl_SubgroupID] = temp[j][n];
+            }
+        }
+    }
+    barrier();
+    if (tid == 0) {
+        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+                temp[j][n] = FLOAT_TYPE(0);
+                [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+                    temp[j][n] += tmpsh[j][n][s];
+                }
+                data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
+            }
+        }
+    }
+#else
     // sum up partial sums and write back result
     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
         [[unroll]] for (uint n = 0; n < num_rows; ++n) {
@@ -115,4 +150,5 @@ void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32
             }
         }
     }
+#endif
 }
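
In the new USE_SUBGROUP_ADD path of reduce_result, each thread's partial dot products are first collapsed with subgroupAdd, lane 0 of every subgroup then stores its subtotal in tmpsh, and after the barrier thread 0 sums the gl_NumSubgroups subtotals and writes data_d; for a 128-thread workgroup on a 32-wide subgroup this leaves thread 0 with only four values to add per output element.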

package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7

@@ -26,6 +26,9 @@ layout (push_constant) uniform parameter
     uint ne12;
     uint b_offset;
     uint d_offset;
+    uint nb03;
+    uint nb13;
+    uint nb23;
 } p;
 
 shared FLOAT_TYPE tmp[BLOCK_SIZE];
@@ -34,6 +37,7 @@ void main() {
     const uint tid = gl_LocalInvocationID.x;
     const uint row_x = gl_GlobalInvocationID.y;
     const uint channel = gl_GlobalInvocationID.z;
+    const uint i3 = gl_WorkGroupID.x;
     const uint channel_x = channel / p.channel_x_divisor;
     const uint channel_y = channel % p.ne12;
 
@@ -41,7 +45,7 @@
     const uint nrows_dst = p.nrows_x;
     const uint row_dst = row_x;
 
-    const uint idst = channel*nrows_dst + row_dst;
+    const uint idst = i3*p.nb23 + channel*nrows_dst + row_dst;
 
     FLOAT_TYPE temp = 0.0f;
 
@@ -58,8 +62,8 @@
 
         const uint row_y = col_x;
 
-        const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
-        const uint iy = channel_y*p.channel_stride_y + row_y;
+        const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
+        const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
 
         const vec4 av4 = vec4(data_a_v4[ix / 4]);
         const vec4 bv4 = vec4(data_b_v4[iy / 4]);
@@ -74,8 +78,8 @@
 
        const uint row_y = col_x;
 
-        const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
-        const uint iy = channel_y*p.channel_stride_y + row_y;
+        const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
+        const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
 
         const vec4 av4 = vec4(data_a_v4[ix / 4]);
         const vec4 bv4 = vec4(data_b_v4[iy / 4]);
@@ -91,8 +95,8 @@
 
        const uint row_y = col_x;
 
-        const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
-        const uint iy = channel_y*p.channel_stride_y + row_y;
+        const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
+        const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
 
         const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);