@novastera-oss/llamarn 0.3.1 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +86 -3
- package/RNLlamaCpp.podspec +1 -1
- package/android/CMakeLists.txt +11 -3
- package/android/generated/jni/react/renderer/components/RNLlamaCppSpec/RNLlamaCppSpecJSI.h +49 -4
- package/android/src/main/cpp/include/llama.h +53 -114
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +2 -10
- package/cpp/PureCppImpl.cpp +71 -4
- package/cpp/SystemUtils.cpp +3 -7
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -1
- package/cpp/llama.cpp/Makefile +6 -1605
- package/cpp/llama.cpp/README.md +5 -1
- package/cpp/llama.cpp/common/arg.cpp +230 -51
- package/cpp/llama.cpp/common/chat-parser.cpp +9 -1
- package/cpp/llama.cpp/common/chat.cpp +539 -8
- package/cpp/llama.cpp/common/chat.h +8 -1
- package/cpp/llama.cpp/common/common.cpp +60 -15
- package/cpp/llama.cpp/common/common.h +64 -15
- package/cpp/llama.cpp/common/speculative.cpp +135 -54
- package/cpp/llama.cpp/common/speculative.h +8 -1
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1216 -109
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +19 -6
- package/cpp/llama.cpp/convert_lora_to_gguf.py +1 -1
- package/cpp/llama.cpp/flake.nix +0 -5
- package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -3
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +71 -70
- package/cpp/llama.cpp/ggml/include/ggml-opt.h +25 -6
- package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +90 -3
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +13 -1
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +10 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +113 -17
- package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +701 -585
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +13 -3
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +52 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +274 -91
- package/cpp/llama.cpp/ggml/src/ggml-common.h +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +90 -569
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +55 -341
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +371 -298
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +33 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +26 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +428 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +458 -46
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
- package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +122 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +9 -11
- package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cu +58 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +275 -170
- package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +103 -65
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +171 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +33 -7
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +2 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +3 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +83 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +116 -57
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +45 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +56 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +61 -39
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +70 -49
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +70 -21
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +162 -50
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +5 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +208 -97
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +46 -35
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +56 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +95 -51
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +427 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +204 -57
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +252 -168
- package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
- package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +10 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +192 -19
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cu +67 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +1 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cu +34 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +16 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +153 -71
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +6 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +21 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +75 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -25
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -1
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +31 -20
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +342 -131
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +464 -134
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1108 -176
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +343 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +343 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +346 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +41 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +97 -41
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +110 -16
- package/cpp/llama.cpp/ggml/src/ggml-quants.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +22 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +0 -212
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +213 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +117 -238
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quantize.hpp +133 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1666 -633
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +107 -43
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +21 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +16 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +44 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +44 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +26 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -17
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +37 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +109 -55
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +71 -41
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -11
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +55 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +75 -20
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +807 -412
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +72 -22
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +1794 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +846 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +204 -50
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +187 -2
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +11 -2
- package/cpp/llama.cpp/gguf-py/gguf/quants.py +53 -4
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +67 -63
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +7 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +120 -16
- package/cpp/llama.cpp/gguf-py/gguf/utility.py +5 -1
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +284 -1
- package/cpp/llama.cpp/gguf-py/tests/test_quants.py +14 -5
- package/cpp/llama.cpp/include/llama.h +53 -114
- package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +171 -0
- package/cpp/llama.cpp/models/templates/README.md +2 -1
- package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +59 -0
- package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +331 -0
- package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +105 -0
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -1
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +0 -6
- package/cpp/llama.cpp/requirements/requirements-pydantic.txt +1 -1
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-adapter.cpp +68 -4
- package/cpp/llama.cpp/src/llama-adapter.h +3 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +192 -2
- package/cpp/llama.cpp/src/llama-arch.h +18 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +2 -2
- package/cpp/llama.cpp/src/llama-chat.cpp +47 -6
- package/cpp/llama.cpp/src/llama-chat.h +3 -0
- package/cpp/llama.cpp/src/llama-context.cpp +61 -252
- package/cpp/llama.cpp/src/llama-context.h +10 -15
- package/cpp/llama.cpp/src/llama-cparams.h +0 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +180 -85
- package/cpp/llama.cpp/src/llama-graph.h +90 -51
- package/cpp/llama.cpp/src/llama-hparams.cpp +34 -3
- package/cpp/llama.cpp/src/llama-hparams.h +21 -6
- package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +79 -56
- package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +30 -28
- package/cpp/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +240 -632
- package/cpp/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +39 -74
- package/cpp/llama.cpp/src/llama-kv-cells.h +21 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +41 -35
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +26 -29
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +13 -9
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +10 -14
- package/cpp/llama.cpp/src/llama-memory.h +13 -10
- package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
- package/cpp/llama.cpp/src/llama-model-loader.h +3 -2
- package/cpp/llama.cpp/src/llama-model.cpp +1959 -419
- package/cpp/llama.cpp/src/llama-model.h +28 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +40 -4
- package/cpp/llama.cpp/src/llama-vocab.cpp +51 -2
- package/cpp/llama.cpp/src/llama-vocab.h +1 -0
- package/cpp/llama.cpp/vendor/minja/chat-template.hpp +16 -7
- package/cpp/llama.cpp/vendor/minja/minja.hpp +47 -12
- package/cpp/rn-completion.cpp +3 -27
- package/ios/generated/RNLlamaCppSpec/RNLlamaCppSpec.h +30 -0
- package/ios/generated/RNLlamaCppSpecJSI.h +49 -4
- package/ios/include/chat.h +8 -1
- package/ios/include/common/minja/chat-template.hpp +16 -7
- package/ios/include/common/minja/minja.hpp +47 -12
- package/ios/include/common.h +64 -15
- package/ios/include/llama.h +53 -114
- package/ios/include/speculative.h +8 -1
- package/ios/libs/llama.xcframework/Info.plist +18 -18
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5557 -5267
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5520 -5238
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4242 -4016
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5556 -5267
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5553 -5303
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5515 -5274
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4238 -4044
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/lib/module/NativeRNLlamaCpp.js.map +1 -1
- package/lib/typescript/src/NativeRNLlamaCpp.d.ts +5 -0
- package/lib/typescript/src/NativeRNLlamaCpp.d.ts.map +1 -1
- package/package.json +1 -2
- package/src/NativeRNLlamaCpp.ts +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -56
|
@@ -1,8 +1,20 @@
|
|
|
1
1
|
#pragma once
|
|
2
2
|
|
|
3
3
|
#include "common.cuh"
|
|
4
|
+
|
|
4
5
|
#include <cstdint>
|
|
5
6
|
|
|
7
|
+
// Load the i32-th 32-bit word from x, assembling it one byte at a time with byte 0
// in the least significant position (little-endian order).
// Unlike get_int_b2 / get_int_b4, which read through uint16_t / int pointers and assume
// 2 / 4 byte alignment, this reads only uint8_t values and therefore needs no alignment of x.
static __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) {
    const uint8_t * x8 = (const uint8_t *) x;

    // Accumulate in an unsigned type: a uint8_t is promoted to (signed) int before shifting,
    // so `byte << 24` with byte >= 0x80 would overflow a signed int (undefined behavior).
    uint32_t x32 = (uint32_t) x8[4*i32 + 0] <<  0;
    x32         |= (uint32_t) x8[4*i32 + 1] <<  8;
    x32         |= (uint32_t) x8[4*i32 + 2] << 16;
    x32         |= (uint32_t) x8[4*i32 + 3] << 24;

    // Reinterpret the assembled bit pattern as int (same bits the original signed shifts produced).
    return (int) x32;
}
|
|
17
|
+
|
|
6
18
|
static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
|
|
7
19
|
const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
|
|
8
20
|
|
|
@@ -16,6 +28,72 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
|
|
|
16
28
|
return ((const int *) x)[i32]; // assume at least 4 byte alignment
|
|
17
29
|
}
|
|
18
30
|
|
|
31
|
+
// q4 contains 8 indices with 4 bits each (two per byte: a low nibble and a high nibble).
// This function selects the bytes of table at those indices and returns them packed as an int2.
// The first int contains the bytes selected by the even indices in q4 (the low nibble of each byte),
// the second int contains the bytes selected by the odd indices in q4 (the high nibble of each byte).
// In both ints, byte n of the result corresponds to byte n of q4.
// NOTE(review): the HIP and CUDA paths read table through a uint32_t pointer (4 words = 16 bytes),
// so they presumably require table to hold at least 16 entries and be 4-byte aligned — confirm at call sites.
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
#if defined(GGML_USE_HIP)
    // Load the 16-byte table into four 32-bit unsigned integers.
    const uint32_t *values = (const uint32_t *)table;

    // Even indices sit in bits 0-3 of each byte; shifting by 4 moves the odd indices (bits 4-7) down
    // into the same position. Bits above each nibble are masked off below, so the shift kind does not matter.
    const uint32_t q_even = q4;
    const uint32_t q_odd = (q4 >> 4);

    // Perform lookups in the lower half of the table (indices 0-7), using only the low 3 bits of each index.
    uint32_t v_even_low = __builtin_amdgcn_perm(values[1], values[0], q_even & 0x07070707);
    uint32_t v_odd_low = __builtin_amdgcn_perm(values[1], values[0], q_odd & 0x07070707);

    // Perform lookups in the upper half of the table (indices 8-15), again using the low 3 bits.
    uint32_t v_even_high = __builtin_amdgcn_perm(values[3], values[2], q_even & 0x07070707);
    uint32_t v_odd_high = __builtin_amdgcn_perm(values[3], values[2], q_odd & 0x07070707);

    // Select between the low and high results based on the MSB (bit 3) of each index nibble:
    // (q & 0x08080808) >> 1 moves that bit to bit 2, i.e. adds 4 to the per-byte selector 0x03020100
    // for every byte whose index is >= 8.
    uint32_t mask_even = 0x03020100 | ((q_even & 0x08080808) >> 1);
    uint32_t res_x = __builtin_amdgcn_perm(v_even_high, v_even_low, mask_even);
    uint32_t mask_odd = 0x03020100 | ((q_odd & 0x08080808) >> 1);
    uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd);

    return make_int2(res_x, res_y);
#elif !defined(GGML_USE_MUSA)
    // CUDA does not have an instruction for selecting bytes with 4 bit indices.
    // However, __byte_perm is an instruction that selects bytes with 3 bit indices that can be used instead.
    const uint32_t * table32 = (const uint32_t *) table;

    // __byte_perm selects bytes based on the lower 16 bits in its third argument.
    // Therefore, do 2 iterations over the 32 bits in q4 with 0 and 16 shift.
    // To handle the fourth bit, first call __byte_perm both for the low and the high 64 bits of table, using the low 3 bits.
    // Then, call __byte_perm again to select from the low and high bytes based on the fourth bit.
    uint32_t tmp[2];
    // Per nibble: selector 0..3 (repeated for both 16-bit halves), plus 4 when bit 3 of the
    // corresponding q4 nibble is set, so the final __byte_perm picks from `high` instead of `low`.
    const uint32_t low_high_selection_indices = (0x32103210 | ((q4 & 0x88888888) >> 1));
#pragma unroll
    for (uint32_t i = 0; i < 2; ++i) {
        const uint32_t shift = 16 * i;

        const uint32_t low = __byte_perm(table32[0], table32[1], q4 >> shift);
        const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift);
        tmp[i] = __byte_perm(low, high, low_high_selection_indices >> shift);
    }

    // tmp contains the bytes from table in the same order as the 4 bit indices in q4
    // (even0, odd0, even1, odd1 | even2, odd2, even3, odd3).
    // However, for the result we need ints with all even/odd 4 bit indices in q4.
    // Therefore, 2 more calls to __byte_perm to put the bytes in the correct order:
    // 0x6420 gathers bytes 0, 2, 4, 6 (even) and 0x7531 gathers bytes 1, 3, 5, 7 (odd).
    return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531));
#else
    // Generic implementation (taken when GGML_USE_MUSA is defined): one table lookup per nibble.
    // Even indices: low nibble of each byte of q4.
    const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
    const int8_t * q0_8 = (const int8_t *) &q0_32;
    const char4 val0_8 = make_char4(
        table[q0_8[0]], table[q0_8[1]], table[q0_8[2]], table[q0_8[3]]);

    // Odd indices: high nibble of each byte of q4, shifted down and masked to 0..15.
    const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
    const int8_t * q1_8 = (const int8_t *) &q1_32;
    const char4 val1_8 = make_char4(
        table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);

    // Reinterpret each char4 as an int via a pointer cast.
    return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
#endif
}
|
|
96
|
+
|
|
19
97
|
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
|
|
20
98
|
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
|
|
21
99
|
|
|
@@ -61,7 +139,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
|
|
|
61
139
|
sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
|
|
62
140
|
}
|
|
63
141
|
|
|
64
|
-
#ifdef
|
|
142
|
+
#ifdef FAST_FP16_AVAILABLE
|
|
65
143
|
const float2 tmp = __half22float2(__hmul2(dm4, ds8));
|
|
66
144
|
const float d4d8 = tmp.x;
|
|
67
145
|
const float m4s8 = tmp.y;
|
|
@@ -70,7 +148,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
|
|
|
70
148
|
const float2 ds8f = __half22float2(ds8);
|
|
71
149
|
const float d4d8 = dm4f.x * ds8f.x;
|
|
72
150
|
const float m4s8 = dm4f.y * ds8f.y;
|
|
73
|
-
#endif //
|
|
151
|
+
#endif // FAST_FP16_AVAILABLE
|
|
74
152
|
|
|
75
153
|
// scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
|
|
76
154
|
return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
|
|
@@ -132,7 +210,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
|
|
|
132
210
|
sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
|
|
133
211
|
}
|
|
134
212
|
|
|
135
|
-
#ifdef
|
|
213
|
+
#ifdef FAST_FP16_AVAILABLE
|
|
136
214
|
const float2 tmp = __half22float2(__hmul2(dm5, ds8));
|
|
137
215
|
const float d5d8 = tmp.x;
|
|
138
216
|
const float m5s8 = tmp.y;
|
|
@@ -141,7 +219,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
|
|
|
141
219
|
const float2 ds8f = __half22float2(ds8);
|
|
142
220
|
const float d5d8 = dm5f.x * ds8f.x;
|
|
143
221
|
const float m5s8 = dm5f.y * ds8f.y;
|
|
144
|
-
#endif //
|
|
222
|
+
#endif // FAST_FP16_AVAILABLE
|
|
145
223
|
|
|
146
224
|
// scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
|
|
147
225
|
return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
|
|
@@ -175,7 +253,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
|
|
|
175
253
|
sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
|
|
176
254
|
}
|
|
177
255
|
|
|
178
|
-
#ifdef
|
|
256
|
+
#ifdef FAST_FP16_AVAILABLE
|
|
179
257
|
const float2 tmp = __half22float2(__hmul2(dm8, ds8));
|
|
180
258
|
const float d8d8 = tmp.x;
|
|
181
259
|
const float m8s8 = tmp.y;
|
|
@@ -184,7 +262,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
|
|
|
184
262
|
const float2 ds8f = __half22float2(ds8);
|
|
185
263
|
const float d8d8 = dm8f.x * ds8f.x;
|
|
186
264
|
const float m8s8 = dm8f.y * ds8f.y;
|
|
187
|
-
#endif //
|
|
265
|
+
#endif // FAST_FP16_AVAILABLE
|
|
188
266
|
|
|
189
267
|
// scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
|
|
190
268
|
return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
|
|
@@ -211,6 +289,30 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_
|
|
|
211
289
|
return d8_1*sumf;
|
|
212
290
|
}
|
|
213
291
|
|
|
292
|
+
#define VDR_MXFP4_Q8_1_MMVQ 2
#define VDR_MXFP4_Q8_1_MMQ 4

// Dot product of one MXFP4 block (4-bit indices into kvalues_mxfp4 plus a shared
// E8M0 exponent `e`) with a q8_1 block, for VDR_MXFP4_Q8_1_MMVQ packed ints per call.
//   vbq   : MXFP4 blocks, indexed by kbx
//   bq8_1 : q8_1 block whose quantized ints are offset by iqs
// The 0.5f factor undoes the doubling of the kvalues_mxfp4 table entries.
static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_mxfp4 * blk = ((const block_mxfp4 *) vbq) + kbx;
    const int         * u   = ((const int *) bq8_1->qs) + iqs;

    int acc = 0;
#pragma unroll
    for (int k = 0; k < VDR_MXFP4_Q8_1_MMVQ; ++k) {
        // one packed int of 4-bit indices -> two ints of int8 table values;
        // .x pairs with u[k], .y with u[k + 4] (same layout as the iq4_nl path)
        const int2 lut = get_int_from_table_16(get_int_b1(blk->qs, iqs + k), kvalues_mxfp4);

        acc = ggml_cuda_dp4a(lut.x, u[k + 0], acc);
        acc = ggml_cuda_dp4a(lut.y, u[k + 4], acc);
    }

    const float scale = ggml_cuda_e8m0_to_fp32(blk->e) * 0.5f * __low2float(bq8_1->ds);
    return scale * acc;
}
|
|
315
|
+
|
|
214
316
|
#define VDR_Q2_K_Q8_1_MMVQ 1
|
|
215
317
|
#define VDR_Q2_K_Q8_1_MMQ 4
|
|
216
318
|
|
|
@@ -1068,20 +1170,6 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
|
|
|
1068
1170
|
return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
|
|
1069
1171
|
}
|
|
1070
1172
|
|
|
1071
|
-
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
|
|
1072
|
-
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
|
|
1073
|
-
const int8_t * q0_8 = (const int8_t *) &q0_32;
|
|
1074
|
-
const char4 val0_8 = make_char4(
|
|
1075
|
-
kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
|
|
1076
|
-
|
|
1077
|
-
const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
|
|
1078
|
-
const int8_t * q1_8 = (const int8_t *) &q1_32;
|
|
1079
|
-
const char4 val1_8 = make_char4(
|
|
1080
|
-
kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]);
|
|
1081
|
-
|
|
1082
|
-
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
|
|
1083
|
-
}
|
|
1084
|
-
|
|
1085
1173
|
#define VDR_IQ4_NL_Q8_1_MMVQ 2
|
|
1086
1174
|
#define VDR_IQ4_NL_Q8_1_MMQ 4
|
|
1087
1175
|
|
|
@@ -1096,7 +1184,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
|
|
|
1096
1184
|
#pragma unroll
|
|
1097
1185
|
for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
|
|
1098
1186
|
const int aux_q4 = get_int_b2(bq4->qs, iqs + l);
|
|
1099
|
-
const int2 v = get_int_from_table_16(aux_q4);
|
|
1187
|
+
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
|
|
1100
1188
|
|
|
1101
1189
|
sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
|
|
1102
1190
|
sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
|
|
@@ -1118,7 +1206,7 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
|
|
|
1118
1206
|
#pragma unroll
|
|
1119
1207
|
for (int j = 0; j < 4; ++j) {
|
|
1120
1208
|
const int aux_q4 = get_int_b4(bq4->qs, iqs + j);
|
|
1121
|
-
const int2 v = get_int_from_table_16(aux_q4);
|
|
1209
|
+
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
|
|
1122
1210
|
|
|
1123
1211
|
const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);
|
|
1124
1212
|
const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);
|
|
@@ -6,6 +6,10 @@
|
|
|
6
6
|
#include <cuda_bf16.h>
|
|
7
7
|
#include <cuda_fp16.h>
|
|
8
8
|
|
|
9
|
+
#if CUDART_VERSION >= 12050
|
|
10
|
+
#include <cuda_fp8.h>
|
|
11
|
+
#endif // CUDART_VERSION >= 12050
|
|
12
|
+
|
|
9
13
|
#if CUDART_VERSION < 11020
|
|
10
14
|
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
|
11
15
|
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
|
|
@@ -1,14 +1,10 @@
|
|
|
1
1
|
#pragma once
|
|
2
2
|
|
|
3
|
-
#define
|
|
3
|
+
#define HIP_DISABLE_WARP_SYNC_BUILTINS 1
|
|
4
4
|
#include <hip/hip_runtime.h>
|
|
5
5
|
#include <hipblas/hipblas.h>
|
|
6
6
|
#include <hip/hip_fp16.h>
|
|
7
|
-
#include <hip/
|
|
8
|
-
#ifdef __HIP_PLATFORM_AMD__
|
|
9
|
-
// for rocblas_initialize()
|
|
10
|
-
#include "rocblas/rocblas.h"
|
|
11
|
-
#endif // __HIP_PLATFORM_AMD__
|
|
7
|
+
#include <hip/hip_bf16.h>
|
|
12
8
|
|
|
13
9
|
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
|
14
10
|
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
|
|
@@ -26,7 +22,10 @@
|
|
|
26
22
|
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
|
|
27
23
|
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
|
|
28
24
|
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
|
|
25
|
+
#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)
|
|
29
26
|
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
|
27
|
+
#define __all_sync(mask, var) __all(var)
|
|
28
|
+
#define __any_sync(mask, var) __any(var)
|
|
30
29
|
#define cublasCreate hipblasCreate
|
|
31
30
|
#define cublasDestroy hipblasDestroy
|
|
32
31
|
#define cublasGemmEx hipblasGemmEx
|
|
@@ -139,7 +138,7 @@
|
|
|
139
138
|
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
|
|
140
139
|
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
|
|
141
140
|
|
|
142
|
-
#if
|
|
141
|
+
#if HIP_VERSION >= 60500000
|
|
143
142
|
#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
|
|
144
143
|
#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
|
|
145
144
|
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
|
|
@@ -151,7 +150,11 @@
|
|
|
151
150
|
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
|
|
152
151
|
#define cublasComputeType_t hipblasDatatype_t
|
|
153
152
|
#define cudaDataType_t hipblasDatatype_t
|
|
154
|
-
#endif
|
|
153
|
+
#endif // HIP_VERSION >= 6050000
|
|
154
|
+
|
|
155
|
+
#if !defined(__HIP_PLATFORM_AMD__)
|
|
156
|
+
#error "The HIP backend supports only AMD targets"
|
|
157
|
+
#endif // !defined(__HIP_PLATFORM_AMD__)
|
|
155
158
|
|
|
156
159
|
#define __CUDA_ARCH__ 1300
|
|
157
160
|
|
|
@@ -179,8 +182,7 @@
|
|
|
179
182
|
#define RDNA4
|
|
180
183
|
#endif
|
|
181
184
|
|
|
182
|
-
#if defined(
|
|
183
|
-
defined(__gfx1150__) || defined(__gfx1151__)
|
|
185
|
+
#if defined(__GFX11__)
|
|
184
186
|
#define RDNA3
|
|
185
187
|
#endif
|
|
186
188
|
|
|
@@ -197,7 +199,8 @@
|
|
|
197
199
|
#define __has_builtin(x) 0
|
|
198
200
|
#endif
|
|
199
201
|
|
|
200
|
-
typedef
|
|
202
|
+
typedef __hip_bfloat16 nv_bfloat16;
|
|
203
|
+
typedef __hip_bfloat162 nv_bfloat162;
|
|
201
204
|
|
|
202
205
|
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
|
|
203
206
|
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
|
|
@@ -248,17 +251,3 @@ static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigne
|
|
|
248
251
|
}
|
|
249
252
|
return c;
|
|
250
253
|
}
|
|
251
|
-
|
|
252
|
-
#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
|
|
253
|
-
// __shfl_xor() for half2 was added in ROCm 5.6
|
|
254
|
-
static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
|
|
255
|
-
typedef union half2_b32 {
|
|
256
|
-
half2 val;
|
|
257
|
-
int b32;
|
|
258
|
-
} half2_b32_t;
|
|
259
|
-
half2_b32_t tmp;
|
|
260
|
-
tmp.val = var;
|
|
261
|
-
tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
|
|
262
|
-
return tmp.val;
|
|
263
|
-
}
|
|
264
|
-
#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
|
|
@@ -137,4 +137,5 @@
|
|
|
137
137
|
#define cudaStreamEndCapture musaStreamEndCapture
|
|
138
138
|
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
|
|
139
139
|
|
|
140
|
-
typedef
|
|
140
|
+
typedef __mt_bfloat16 nv_bfloat16;
|
|
141
|
+
typedef __mt_bfloat162 nv_bfloat162;
|
|
@@ -46,8 +46,8 @@ if (GGML_HIP_ROCWMMA_FATTN)
|
|
|
46
46
|
endif()
|
|
47
47
|
endif()
|
|
48
48
|
|
|
49
|
-
if (${hip_VERSION} VERSION_LESS
|
|
50
|
-
message(FATAL_ERROR "At least ROCM/HIP
|
|
49
|
+
if (${hip_VERSION} VERSION_LESS 6.1)
|
|
50
|
+
message(FATAL_ERROR "At least ROCM/HIP V6.1 is required")
|
|
51
51
|
endif()
|
|
52
52
|
|
|
53
53
|
message(STATUS "HIP and hipBLAS found")
|
|
@@ -113,10 +113,18 @@ if (GGML_HIP_ROCWMMA_FATTN)
|
|
|
113
113
|
add_compile_definitions(GGML_HIP_ROCWMMA_FATTN)
|
|
114
114
|
endif()
|
|
115
115
|
|
|
116
|
+
if (NOT GGML_HIP_MMQ_MFMA)
|
|
117
|
+
add_compile_definitions(GGML_HIP_NO_MMQ_MFMA)
|
|
118
|
+
endif()
|
|
119
|
+
|
|
116
120
|
if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0)
|
|
117
121
|
add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12)
|
|
118
122
|
endif()
|
|
119
123
|
|
|
124
|
+
if (GGML_HIP_EXPORT_METRICS)
|
|
125
|
+
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps")
|
|
126
|
+
endif()
|
|
127
|
+
|
|
120
128
|
if (NOT GGML_CUDA_FA)
|
|
121
129
|
add_compile_definitions(GGML_CUDA_NO_FA)
|
|
122
130
|
endif()
|
|
@@ -410,6 +410,67 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
|
|
|
410
410
|
#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
|
|
411
411
|
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
|
|
412
412
|
|
|
413
|
+
// Decode an E8M0 scale byte (a biased power-of-two exponent) to 2^(x - 127).
//   x != 0 : the byte drops straight into the float exponent field with a
//            zero mantissa, giving the normal number 2^(x - 127).
//   x == 0 : 2^(-127) has no normal float encoding, so it is produced as a
//            denormal with only the top mantissa bit set (0x00400000).
// x == 0xFF (NaN in E8M0) is intentionally not special-cased; it maps to +inf.
static inline float ggml_e8m0_to_fp32(uint8_t x) {
    const uint32_t bits = (x != 0) ? ((uint32_t) x << 23) : 0x00400000;

    // reinterpret the bit pattern via memcpy to avoid strict-aliasing UB
    float result;
    memcpy(&result, &bits, sizeof(float));
    return result;
}
|
|
448
|
+
|
|
449
|
+
// Equal to ggml_e8m0_to_fp32(x)/2, i.e. 2^(x - 128), computed exactly.
// Useful with MXFP4 quantization: the FP4 (E2M1) element values are stored
// doubled (so they become integers), so the shared E8M0 block scale is halved.
static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
    uint32_t bits;

    if (x < 2) {
        // 2^(x-128) is below FLT_MIN (2^-126) for x < 2, so emit a denormal:
        // 0x00200000 = 2^(-128), 0x00400000 = 2^(-127)
        bits = 0x00200000u << x;
    } else {
        // 0.5 * 2^(x-127) = 2^(x-128) = normal float with biased exponent (x - 1)
        bits = (uint32_t)(x - 1) << 23;
    }
    // Note: NaNs are not handled here; x == 0xFF yields 2^127 rather than NaN.

    // reinterpret the bit pattern via memcpy to avoid strict-aliasing UB
    float result;
    memcpy(&result, &bits, sizeof(float));
    return result;
}

#define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
#define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
|
|
473
|
+
|
|
413
474
|
/**
|
|
414
475
|
* Converts brain16 to float32.
|
|
415
476
|
*
|
|
@@ -23,6 +23,9 @@
|
|
|
23
23
|
#define N_R0_Q8_0 4
|
|
24
24
|
#define N_SG_Q8_0 2
|
|
25
25
|
|
|
26
|
+
// MXFP4 launch tiling. NOTE(review): presumably follows the other N_R0_*/N_SG_*
// pairs (rows per simdgroup / simdgroups per threadgroup); confirm in the kernels.
#define N_R0_MXFP4  2
#define N_SG_MXFP4  2
|
|
28
|
+
|
|
26
29
|
#define N_R0_Q2_K 4
|
|
27
30
|
#define N_SG_Q2_K 2
|
|
28
31
|
|
|
@@ -129,6 +132,15 @@ typedef struct {
|
|
|
129
132
|
uint64_t o1[8];
|
|
130
133
|
} ggml_metal_kargs_bin;
|
|
131
134
|
|
|
135
|
+
// Kernel arguments for the Metal add_id op. Field order and types are part of
// the host/shader ABI and must not change.
// NOTE(review): names appear to follow ggml's convention (ne* = element counts,
// nb* = byte strides); confirm against the host-side encoder.
typedef struct {
    int64_t ne0;
    int64_t ne1;
    size_t  nb01;
    size_t  nb02;
    size_t  nb11;
    size_t  nb21;
} ggml_metal_kargs_add_id;
|
|
143
|
+
|
|
132
144
|
typedef struct {
|
|
133
145
|
int32_t ne00;
|
|
134
146
|
int32_t ne01;
|
|
@@ -237,6 +249,7 @@ typedef struct {
|
|
|
237
249
|
uint64_t nb33;
|
|
238
250
|
int32_t ne1;
|
|
239
251
|
int32_t ne2;
|
|
252
|
+
int32_t ne3;
|
|
240
253
|
float scale;
|
|
241
254
|
float max_bias;
|
|
242
255
|
float m0;
|
|
@@ -245,6 +258,11 @@ typedef struct {
|
|
|
245
258
|
float logit_softcap;
|
|
246
259
|
} ggml_metal_kargs_flash_attn_ext;
|
|
247
260
|
|
|
261
|
+
// Kernel arguments for the flash_attn_ext reduce kernel (two 32-bit ints;
// layout is shared with the shader and must stay as-is).
// NOTE(review): ne20 presumably follows ggml's ne* naming; confirm with the caller.
typedef struct {
    int32_t nrows;
    int32_t ne20;
} ggml_metal_kargs_flash_attn_ext_reduce;
|
|
265
|
+
|
|
248
266
|
typedef struct {
|
|
249
267
|
int32_t ne00;
|
|
250
268
|
int32_t ne02;
|
|
@@ -308,40 +326,31 @@ typedef struct {
|
|
|
308
326
|
} ggml_metal_kargs_mul_mv_ext;
|
|
309
327
|
|
|
310
328
|
typedef struct {
|
|
329
|
+
int32_t ne02;
|
|
311
330
|
int32_t ne10;
|
|
312
331
|
int32_t ne11; // n_expert_used (bcast)
|
|
313
332
|
uint64_t nb11;
|
|
314
333
|
uint64_t nb12;
|
|
315
|
-
int32_t
|
|
316
|
-
uint64_t nbh11;
|
|
334
|
+
int32_t ne21; // n_tokens
|
|
317
335
|
int32_t ne20; // n_expert_used
|
|
318
336
|
uint64_t nb21;
|
|
319
337
|
} ggml_metal_kargs_mul_mm_id_map0;
|
|
320
338
|
|
|
321
|
-
typedef struct {
|
|
322
|
-
int32_t ne20; // n_expert_used
|
|
323
|
-
int32_t neh0;
|
|
324
|
-
int32_t neh1;
|
|
325
|
-
uint64_t nbh1;
|
|
326
|
-
uint64_t nbh2;
|
|
327
|
-
int32_t ne0;
|
|
328
|
-
uint64_t nb1;
|
|
329
|
-
uint64_t nb2;
|
|
330
|
-
} ggml_metal_kargs_mul_mm_id_map1;
|
|
331
|
-
|
|
332
339
|
typedef struct {
|
|
333
340
|
int32_t ne00;
|
|
334
341
|
int32_t ne02;
|
|
335
342
|
uint64_t nb01;
|
|
336
343
|
uint64_t nb02;
|
|
337
344
|
uint64_t nb03;
|
|
338
|
-
int32_t
|
|
339
|
-
uint64_t
|
|
340
|
-
uint64_t
|
|
341
|
-
uint64_t
|
|
342
|
-
uint64_t
|
|
343
|
-
int32_t
|
|
344
|
-
int32_t
|
|
345
|
+
int32_t ne11;
|
|
346
|
+
uint64_t nb10;
|
|
347
|
+
uint64_t nb11;
|
|
348
|
+
uint64_t nb12;
|
|
349
|
+
uint64_t nb13;
|
|
350
|
+
int32_t ne20;
|
|
351
|
+
int32_t ne21;
|
|
352
|
+
int32_t ne0;
|
|
353
|
+
int32_t ne1;
|
|
345
354
|
int16_t r2;
|
|
346
355
|
int16_t r3;
|
|
347
356
|
} ggml_metal_kargs_mul_mm_id;
|
|
@@ -444,6 +453,8 @@ typedef struct{
|
|
|
444
453
|
uint64_t nb1;
|
|
445
454
|
int32_t i00;
|
|
446
455
|
int32_t i10;
|
|
456
|
+
float alpha;
|
|
457
|
+
float limit;
|
|
447
458
|
} ggml_metal_kargs_glu;
|
|
448
459
|
|
|
449
460
|
typedef struct {
|