@novastera-oss/llamarn 0.3.1 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +86 -3
- package/RNLlamaCpp.podspec +1 -1
- package/android/CMakeLists.txt +11 -3
- package/android/generated/jni/react/renderer/components/RNLlamaCppSpec/RNLlamaCppSpecJSI.h +49 -4
- package/android/src/main/cpp/include/llama.h +53 -114
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +2 -10
- package/cpp/PureCppImpl.cpp +71 -4
- package/cpp/SystemUtils.cpp +3 -7
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -1
- package/cpp/llama.cpp/Makefile +6 -1605
- package/cpp/llama.cpp/README.md +5 -1
- package/cpp/llama.cpp/common/arg.cpp +230 -51
- package/cpp/llama.cpp/common/chat-parser.cpp +9 -1
- package/cpp/llama.cpp/common/chat.cpp +539 -8
- package/cpp/llama.cpp/common/chat.h +8 -1
- package/cpp/llama.cpp/common/common.cpp +60 -15
- package/cpp/llama.cpp/common/common.h +64 -15
- package/cpp/llama.cpp/common/speculative.cpp +135 -54
- package/cpp/llama.cpp/common/speculative.h +8 -1
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1216 -109
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +19 -6
- package/cpp/llama.cpp/convert_lora_to_gguf.py +1 -1
- package/cpp/llama.cpp/flake.nix +0 -5
- package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -3
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +71 -70
- package/cpp/llama.cpp/ggml/include/ggml-opt.h +25 -6
- package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +90 -3
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +13 -1
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +10 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +113 -17
- package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +701 -585
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +13 -3
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +52 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +274 -91
- package/cpp/llama.cpp/ggml/src/ggml-common.h +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +90 -569
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +55 -341
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +371 -298
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +33 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +26 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +428 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +458 -46
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
- package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +122 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +9 -11
- package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cu +58 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +275 -170
- package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +103 -65
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +171 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +33 -7
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +2 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +3 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +83 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +116 -57
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +45 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +56 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +61 -39
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +70 -49
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +70 -21
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +162 -50
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +5 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +208 -97
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +46 -35
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +56 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +95 -51
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +427 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +204 -57
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +252 -168
- package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
- package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +10 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +192 -19
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cu +67 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +1 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cu +34 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +16 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +153 -71
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +6 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +21 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +75 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -25
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -1
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +31 -20
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +342 -131
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +464 -134
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1108 -176
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +343 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +343 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +346 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +41 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +97 -41
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +110 -16
- package/cpp/llama.cpp/ggml/src/ggml-quants.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +22 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +0 -212
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +213 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +117 -238
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quantize.hpp +133 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1666 -633
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +107 -43
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +21 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +16 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +44 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +44 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +26 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -17
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +37 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +109 -55
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +71 -41
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -11
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +55 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +75 -20
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +807 -412
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +72 -22
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +1794 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +846 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +204 -50
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +187 -2
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +11 -2
- package/cpp/llama.cpp/gguf-py/gguf/quants.py +53 -4
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +67 -63
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +7 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +120 -16
- package/cpp/llama.cpp/gguf-py/gguf/utility.py +5 -1
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +284 -1
- package/cpp/llama.cpp/gguf-py/tests/test_quants.py +14 -5
- package/cpp/llama.cpp/include/llama.h +53 -114
- package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +171 -0
- package/cpp/llama.cpp/models/templates/README.md +2 -1
- package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +59 -0
- package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +331 -0
- package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +105 -0
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -1
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +0 -6
- package/cpp/llama.cpp/requirements/requirements-pydantic.txt +1 -1
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-adapter.cpp +68 -4
- package/cpp/llama.cpp/src/llama-adapter.h +3 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +192 -2
- package/cpp/llama.cpp/src/llama-arch.h +18 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +2 -2
- package/cpp/llama.cpp/src/llama-chat.cpp +47 -6
- package/cpp/llama.cpp/src/llama-chat.h +3 -0
- package/cpp/llama.cpp/src/llama-context.cpp +61 -252
- package/cpp/llama.cpp/src/llama-context.h +10 -15
- package/cpp/llama.cpp/src/llama-cparams.h +0 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +180 -85
- package/cpp/llama.cpp/src/llama-graph.h +90 -51
- package/cpp/llama.cpp/src/llama-hparams.cpp +34 -3
- package/cpp/llama.cpp/src/llama-hparams.h +21 -6
- package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +79 -56
- package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +30 -28
- package/cpp/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +240 -632
- package/cpp/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +39 -74
- package/cpp/llama.cpp/src/llama-kv-cells.h +21 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +41 -35
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +26 -29
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +13 -9
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +10 -14
- package/cpp/llama.cpp/src/llama-memory.h +13 -10
- package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
- package/cpp/llama.cpp/src/llama-model-loader.h +3 -2
- package/cpp/llama.cpp/src/llama-model.cpp +1959 -419
- package/cpp/llama.cpp/src/llama-model.h +28 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +40 -4
- package/cpp/llama.cpp/src/llama-vocab.cpp +51 -2
- package/cpp/llama.cpp/src/llama-vocab.h +1 -0
- package/cpp/llama.cpp/vendor/minja/chat-template.hpp +16 -7
- package/cpp/llama.cpp/vendor/minja/minja.hpp +47 -12
- package/cpp/rn-completion.cpp +3 -27
- package/ios/generated/RNLlamaCppSpec/RNLlamaCppSpec.h +30 -0
- package/ios/generated/RNLlamaCppSpecJSI.h +49 -4
- package/ios/include/chat.h +8 -1
- package/ios/include/common/minja/chat-template.hpp +16 -7
- package/ios/include/common/minja/minja.hpp +47 -12
- package/ios/include/common.h +64 -15
- package/ios/include/llama.h +53 -114
- package/ios/include/speculative.h +8 -1
- package/ios/libs/llama.xcframework/Info.plist +18 -18
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5557 -5267
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5520 -5238
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4242 -4016
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5556 -5267
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5553 -5303
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5515 -5274
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4238 -4044
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/lib/module/NativeRNLlamaCpp.js.map +1 -1
- package/lib/typescript/src/NativeRNLlamaCpp.d.ts +5 -0
- package/lib/typescript/src/NativeRNLlamaCpp.d.ts.map +1 -1
- package/package.json +1 -2
- package/src/NativeRNLlamaCpp.ts +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -56
The remainder of this diff shows the changes to `package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp` (+428 -23):

```diff
@@ -8,6 +8,7 @@
 #include "vec.h"
 
 #include <float.h>
+#include <algorithm>
 
 // ggml_compute_forward_dup
 
```
```diff
@@ -1283,6 +1284,7 @@ void ggml_compute_forward_add(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
```
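This hunk, and the matching one-liners in the `add1`, `acc`, `out_prod`, `set`, `get_rows`, and `clamp` hunks below, whitelist the new `GGML_TYPE_MXFP4` quantization type in existing type switches. For orientation, here is a minimal dequantization sketch of the OCP Microscaling FP4 format the type name refers to: 32-element blocks sharing one E8M0 power-of-two scale, each element a 4-bit E2M1 value. The struct layout and nibble order below are illustrative assumptions, not necessarily ggml's exact encoding.

```cpp
// Hypothetical MXFP4 block layout (mirrors the OCP MX spec; ggml's actual
// struct may differ): one shared E8M0 exponent + 32 packed E2M1 nibbles.
#include <cmath>
#include <cstdint>

struct mxfp4_block {
    uint8_t e;      // shared scale, value = 2^(e - 127)
    uint8_t qs[16]; // 32 x 4-bit E2M1 elements, two per byte
};

// all values representable by E2M1 (1 sign, 2 exponent, 1 mantissa bit)
static const float e2m1_table[16] = {
     0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
};

void dequant_mxfp4(const mxfp4_block & b, float out[32]) {
    const float d = std::ldexp(1.0f, (int) b.e - 127); // decode E8M0 scale
    for (int i = 0; i < 16; ++i) {
        out[2*i + 0] = d * e2m1_table[b.qs[i] & 0x0F]; // low nibble
        out[2*i + 1] = d * e2m1_table[b.qs[i] >> 4];   // high nibble
    }
}
```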
```diff
@@ -1309,6 +1311,77 @@ void ggml_compute_forward_add(
     }
 }
 
+// ggml_compute_forward_add_id
+
+static void ggml_compute_forward_add_id_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        // src1 indices
+        const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
+
+        GGML_ASSERT(i11 >= 0 && i11 < ne11);
+
+        ggml_vec_add_f32(ne0,
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+                (float *) ((char *) src1->data + i11*nb11));
+    }
+}
+
+void ggml_compute_forward_add_id(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_add_id_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
+            }
+    }
+}
+
 // ggml_compute_forward_add1
 
 static void ggml_compute_forward_add1_f32(
```
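The new `add_id` operator adds one row of `src1` to each row of `src0`, with the row chosen per destination row by the `int32` indices in `src2` — a gathered bias add, useful for things like per-expert biases in mixture-of-experts feed-forward blocks. A plain-array sketch of the semantics, with hypothetical container types standing in for ggml tensors:

```cpp
// Reference semantics of add_id on plain arrays (illustrative sketch):
//   dst[r][c] = src0[r][c] + src1[ids[r]][c]
#include <cassert>
#include <cstdint>
#include <vector>

using matrix = std::vector<std::vector<float>>;

matrix add_id(const matrix & src0,                // [nr][nc] activations
              const matrix & src1,                // [n_rows][nc] bias rows
              const std::vector<int32_t> & ids) { // [nr] row selectors
    matrix dst = src0;
    for (size_t r = 0; r < src0.size(); ++r) {
        assert(ids[r] >= 0 && (size_t) ids[r] < src1.size()); // mirrors the GGML_ASSERT
        for (size_t c = 0; c < src0[r].size(); ++c) {
            dst[r][c] += src1[ids[r]][c];
        }
    }
    return dst;
}
```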
```diff
@@ -1660,6 +1733,7 @@ void ggml_compute_forward_add1(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
        case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -1787,6 +1861,7 @@ void ggml_compute_forward_acc(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
```
```diff
@@ -3614,6 +3689,93 @@ static void ggml_compute_forward_swiglu(
     }
 }
 
+// ggml_compute_forward_swiglu_oai
+
+static void ggml_compute_forward_swiglu_oai_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+    const float alpha = ggml_get_op_params_f32(dst, 2);
+    const float limit = ggml_get_op_params_f32(dst, 3);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+        float * dst_p  = (float *) ((char *) dst->data + i1*(dst->nb[1]));
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        for (int k = 0; k < nc; k++) {
+            const float x = std::min(src0_p[k], limit);
+            const float y = std::clamp(src1_p[k], -limit, limit);
+            const float out_glu = x / (1.f + expf(alpha * (-x)));
+            dst_p[k] = out_glu * (y + 1.f);
+        }
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = dst_p[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_swiglu_oai(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_swiglu_oai_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_geglu_erf
 
 static void ggml_compute_forward_geglu_erf_f32(
```
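`SWIGLU_OAI` is a clamped, scaled SwiGLU variant: the gate `x` is clipped from above at `limit`, the linear branch `y` is clipped to `[-limit, limit]`, the sigmoid gets a slope `alpha` (note `x / (1 + exp(-alpha*x))` is just `x * sigmoid(alpha*x)`), and the output is multiplied by `(y + 1)` rather than `y`. A per-element sketch; the `alpha`/`limit` defaults below are illustrative assumptions, since the kernel reads the real values from op params:

```cpp
// One element of SWIGLU_OAI, matching the inner loop of the hunk above.
// Default alpha/limit are placeholders for illustration only.
#include <algorithm>
#include <cmath>
#include <cstdio>

float swiglu_oai(float x, float y, float alpha = 1.702f, float limit = 7.0f) {
    x = std::min(x, limit);           // gate clamped from above only
    y = std::clamp(y, -limit, limit); // linear branch clamped both ways
    const float glu = x / (1.0f + std::exp(-alpha * x)); // x * sigmoid(alpha*x)
    return glu * (y + 1.0f);
}

int main() {
    std::printf("%f\n", swiglu_oai(2.0f, 0.5f)); // ~2.90
    return 0;
}
```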
```diff
@@ -4599,6 +4761,7 @@ void ggml_compute_forward_out_prod(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -4873,6 +5036,7 @@ void ggml_compute_forward_set(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -5134,6 +5298,7 @@ void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
```
```diff
@@ -5523,6 +5688,7 @@ static void ggml_compute_forward_soft_max_f32(
 
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
 
     assert(ggml_is_contiguous(dst));
     assert(ggml_are_same_shape(src0, dst));
@@ -5557,6 +5723,9 @@ static void ggml_compute_forward_soft_max_f32(
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
+    // sinks
+    const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
+
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
@@ -5599,9 +5768,18 @@ static void ggml_compute_forward_soft_max_f32(
                 float max = -INFINITY;
                 ggml_vec_max_f32(ne00, &max, wp);
 
+                // if we have sinks, make a correction as if they were included in the softmax
+                if (sk) {
+                    max = MAX(max, sk[i02]);
+                }
+
                 ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
                 assert(sum > 0.0);
 
+                if (sk) {
+                    sum += (ggml_float) expf(sk[i02] - max);
+                }
+
                 sum = 1.0/sum;
                 ggml_vec_scale_f32(ne00, dp, sum);
 
```
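These three hunks thread an optional `src2` tensor of per-head "attention sink" logits through the CPU softmax. A sink acts as an extra logit that joins the running maximum and the normalizer but emits no output, so each row then sums to slightly less than one: `p_i = exp(z_i - m) / (sum_j exp(z_j - m) + exp(s - m))` with `m = max(max_j z_j, s)`. A scalar sketch of the same steps:

```cpp
// Softmax over one row with a single sink logit, mirroring the hunks above
// (max correction first, then an extra term in the denominator). Sketch only.
#include <algorithm>
#include <cmath>
#include <vector>

void softmax_with_sink(std::vector<float> & z, float sink) {
    float m = *std::max_element(z.begin(), z.end());
    m = std::max(m, sink);            // max = MAX(max, sk[i02])
    float sum = 0.0f;
    for (float & v : z) {
        v = std::exp(v - m);
        sum += v;
    }
    sum += std::exp(sink - m);        // sum += expf(sk[i02] - max)
    for (float & v : z) {
        v /= sum;                     // row now sums to less than 1
    }
}
```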
```diff
@@ -5836,6 +6014,7 @@ void ggml_compute_forward_clamp(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
```
```diff
@@ -7028,6 +7207,148 @@ void ggml_compute_forward_conv_2d(
     ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
 }
 
+// ggml_compute_forward_conv_3d
+
+static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
+                                              const ggml_tensor *         kernel,
+                                              const ggml_tensor *         src,
+                                              ggml_tensor *               dst,
+                                              ggml_type                   kernel_type) {
+
+    GGML_ASSERT(ggml_is_contiguous(kernel));
+    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
+    GGML_ASSERT(kernel->type == kernel_type);
+
+    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
+
+    const int32_t s0 = dst->op_params[0];
+    const int32_t s1 = dst->op_params[1];
+    const int32_t s2 = dst->op_params[2];
+    const int32_t p0 = dst->op_params[3];
+    const int32_t p1 = dst->op_params[4];
+    const int32_t p2 = dst->op_params[5];
+    const int32_t d0 = dst->op_params[6];
+    const int32_t d1 = dst->op_params[7];
+    const int32_t d2 = dst->op_params[8];
+    const int32_t c  = dst->op_params[9];
+    const int32_t n  = dst->op_params[10];
+    const int32_t oc = dst->op_params[11];
+
+    const int64_t src_w = src->ne[0];
+    const int64_t src_h = src->ne[1];
+    const int64_t src_d = src->ne[2];
+    const int64_t knl_w = kernel->ne[0];
+    const int64_t knl_h = kernel->ne[1];
+    const int64_t knl_d = kernel->ne[2];
+    const int64_t dst_w = dst->ne[0];
+    const int64_t dst_h = dst->ne[1];
+    const int64_t dst_d = dst->ne[2];
+
+    const float * src_data = (float *) src->data;
+    void  * knl_data = kernel->data;
+    float * dst_data = (float *) dst->data;
+
+    const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
+    const int64_t knl_n_total       = knl_n_per_channel * c;
+    const int64_t patch_total       = n * dst_w * dst_h * dst_d;
+
+    const int64_t space_per_patch   = knl_n_total * traits->type_size + oc * sizeof(float);
+    const int64_t batch_size        = params->wsize / space_per_patch;
+    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
+    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
+
+    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
+
+    void * tmp = params->wdata;
+
+    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
+        const int64_t patch_start_batch = batch_i * patches_per_batch;
+        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch, patch_total);
+        const int64_t patch_n_in_batch  = patch_end_batch - patch_start_batch;
+
+        const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
+        const int64_t patch_start      = patch_start_batch + params->ith * patch_per_thread;
+        const int64_t patch_end        = std::min(patch_start + patch_per_thread, patch_end_batch);
+
+        for (int64_t p = patch_start; p < patch_end; ++p) {
+            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
+            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
+            const int64_t dst_y      = p_in_depth / dst_w;
+            const int64_t dst_x      = p_in_depth % dst_w;
+
+            char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
+
+            for (int64_t ic = 0; ic < c; ++ic) {
+                for (int64_t kz = 0; kz < knl_d; ++kz) {
+                    for (int64_t ky = 0; ky < knl_h; ++ky) {
+                        for (int64_t kx = 0; kx < knl_w; ++kx) {
+                            const int64_t sz = dst_z * s2 + kz * d2 - p2;
+                            const int64_t sy = dst_y * s1 + ky * d1 - p1;
+                            const int64_t sx = dst_x * s0 + kx * d0 - p0;
+
+                            int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
+
+                            float src_val;
+                            if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
+                                src_val = 0.0f;
+                            } else {
+                                const int64_t cn_idx = batch_idx * c + ic;
+                                const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
+                                src_val = *src_ptr;
+                            }
+
+                            char * element_ptr = dst_row + dst_idx * traits->type_size;
+                            if (kernel_type == GGML_TYPE_F32) {
+                                *(float *)element_ptr = src_val;
+                            } else if (kernel_type == GGML_TYPE_F16) {
+                                *(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        ggml_barrier(params->threadpool);
+
+        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
+        ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
+
+        ggml_barrier(params->threadpool);
+
+        const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
+        const int64_t permute_start = params->ith * permute_per_thread;
+        const int64_t permute_end   = std::min(permute_start + permute_per_thread, patch_n_in_batch);
+
+        for (int64_t i = permute_start; i < permute_end; ++i) {
+            const int64_t p = patch_start_batch + i;
+            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
+            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
+            const int64_t dst_y      = p_in_depth / dst_w;
+            const int64_t dst_x      = p_in_depth % dst_w;
+
+            for (int64_t ioc = 0; ioc < oc; ++ioc) {
+                const float value = gemm_output[i * oc + ioc];
+                const int64_t ocn_idx = batch_idx * oc + ioc;
+                float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
+                *dst_ptr = value;
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_conv_3d(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
+}
+
 // ggml_compute_forward_conv_transpose_2d
 
 void ggml_compute_forward_conv_transpose_2d(
```
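The new CPU `conv_3d` follows the same im2col + GEMM strategy as `conv_2d`: each output voxel's receptive field (`knl_w * knl_h * knl_d * c` values, zero-filled where the window hangs over the border) is flattened into one scratch-buffer row, a single matrix multiply against the flattened kernel yields all `oc` channels for a whole batch of patches, and a final threaded pass permutes the GEMM output into the destination layout. Patches are processed in batches sized so `params->wsize` can hold both the im2col rows and the GEMM result. The destination extents follow the usual convolution arithmetic; a one-axis helper, as a sketch:

```cpp
// Standard convolution output-size arithmetic the kernel above relies on;
// the same formula applies independently to the w, h, and d axes. Sketch.
#include <cstdint>

int64_t conv_out_size(int64_t in, int64_t kernel,
                      int64_t stride, int64_t pad, int64_t dilation) {
    const int64_t eff = dilation * (kernel - 1) + 1; // dilated kernel extent
    return (in + 2 * pad - eff) / stride + 1;
}
// e.g. conv_out_size(16, 3, 1, 1, 1) == 16 : a "same"-padded 3-tap axis
```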
```diff
@@ -7989,12 +8310,14 @@ void ggml_compute_forward_argsort(
 
 static void ggml_compute_forward_flash_attn_ext_f16(
         const ggml_compute_params * params,
-        const ggml_tensor * q,
-        const ggml_tensor * k,
-        const ggml_tensor * v,
-        const ggml_tensor * mask,
         ggml_tensor * dst) {
 
+    const ggml_tensor * q     = dst->src[0];
+    const ggml_tensor * k     = dst->src[1];
+    const ggml_tensor * v     = dst->src[2];
+    const ggml_tensor * mask  = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
+
     GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
     GGML_TENSOR_LOCALS(size_t,  nbq, q, nb)
     GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
@@ -8189,6 +8512,23 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             }
         }
 
+        // sinks
+        if (sinks) {
+            const float s = ((float *)((char *) sinks->data))[h];
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (s > M) {
+                ms = expf(M - s);
+                ggml_vec_scale_f32(DV, VKQ32, ms);
+            } else {
+                vs = expf(s - M);
+            }
+
+            S = S*ms + vs;
+        }
+
         // V /= S
         const float S_inv = 1.0f/S;
         ggml_vec_scale_f32(DV, VKQ32, S_inv);
@@ -8208,17 +8548,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
 void ggml_compute_forward_flash_attn_ext(
         const ggml_compute_params * params,
-        const ggml_tensor * q,
-        const ggml_tensor * k,
-        const ggml_tensor * v,
-        const ggml_tensor * mask,
         ggml_tensor * dst) {
     switch (dst->op_params[3]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {
                 // uses F32 accumulators
-                ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+                ggml_compute_forward_flash_attn_ext_f16(params, dst);
             } break;
         default:
             {
```
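Two things change here: the flash-attention kernel now pulls `q`/`k`/`v`/`mask` from `dst->src[...]` (which lets the optional `sinks` tensor ride along as `dst->src[4]` without widening every signature), and the sink is folded in after the streamed pass over KV blocks. Since the kernel maintains a running maximum `M` and denominator `S`, the sink uses the standard online-softmax rescale: if `s > M` the value accumulator is scaled by `exp(M - s)`, otherwise the sink merely adds `exp(s - M)` to `S`. A scalar sketch:

```cpp
// End-of-stream sink correction for an online softmax, mirroring the hunk
// above: (M, S, acc) are the running max / denominator / value-sum. Sketch.
#include <cmath>

void apply_sink(float sink, float & M, float & S, float * acc, int DV) {
    float ms = 1.0f; // rescale factor for what was accumulated so far
    float vs = 1.0f; // the sink's own weight in the denominator
    if (sink > M) {
        ms = std::exp(M - sink);  // accumulated terms were relative to M
        for (int i = 0; i < DV; ++i) {
            acc[i] *= ms;
        }
        M = sink;                 // the sink becomes the new maximum
    } else {
        vs = std::exp(sink - M);
    }
    S = S * ms + vs;              // sink adds weight but no value vector
}
```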
```diff
@@ -8667,8 +9003,7 @@ static void ggml_compute_forward_ssm_scan_f32(
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
     GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
-
-    GGML_ASSERT((ng & -ng) == ng);
+    GGML_ASSERT(nh % ng == 0);
 
     // heads per thread
     const int dh = (nh + nth - 1)/nth;
@@ -8699,6 +9034,7 @@ static void ggml_compute_forward_ssm_scan_f32(
             // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
             const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
             const float dA = expf(dt_soft_plus * A[h]);
+            const int g = h / (nh / ng); // repeat_interleave
 
             // dim
             for (int i1 = 0; i1 < nr; ++i1) {
@@ -8721,8 +9057,8 @@ static void ggml_compute_forward_ssm_scan_f32(
                     // TODO: maybe unroll more?
                     for (int j = 0; j < 1; j++) {
                         GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
-                        GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
-                        GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+                        GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
+                        GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
 
                         t0 = GGML_F32_VEC_MUL(t0, adA);
                         t1 = GGML_F32_VEC_MUL(t1, axdt);
@@ -8736,6 +9072,9 @@ static void ggml_compute_forward_ssm_scan_f32(
                 }
 
                 sumf = GGML_F32xt_REDUCE_ONE(sum);
+#elif defined(__riscv_v_intrinsic)
+                // todo: RVV implementation
+                const int np = 0;
 #else
                 const int np = (nc & ~(GGML_F32_STEP - 1));
 
@@ -8751,8 +9090,8 @@ static void ggml_compute_forward_ssm_scan_f32(
                 for (int i = 0; i < np; i += GGML_F32_STEP) {
                     for (int j = 0; j < GGML_F32_ARR; j++) {
                         ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
-                        ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
-                        az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+                        ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc);
+                        az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc);
 
                         ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
                         ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
@@ -8774,7 +9113,7 @@ static void ggml_compute_forward_ssm_scan_f32(
                 // d_state
                 for (int i0 = np; i0 < nc; ++i0) {
                     const int i = i0 + ii*nc;
-                    const int ig = i0 + (h & (ng - 1))*nc;
+                    const int ig = i0 + g*nc;
                     // state = prev_state * dA + dB * x
                     const float state = (s0[i] * dA) + (B[ig] * x_dt);
                     // y = rowwise_dotprod(state, C)
@@ -8791,6 +9130,7 @@ static void ggml_compute_forward_ssm_scan_f32(
         for (int h = ih0; h < ih1; ++h) {
             // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
             const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+            const int g = h / (nh / ng); // repeat_interleave
 
             // dim
             for (int i1 = 0; i1 < nr; ++i1) {
@@ -8805,8 +9145,8 @@ static void ggml_compute_forward_ssm_scan_f32(
                     // TODO: what happens when (d_state % svcntw()) != 0?
                     for (int64_t k = 0; k < nc; k += svcntw()) {
                         svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
-                        svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
-                        svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
+                        svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
+                        svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
                         svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
 
                         svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
@@ -8826,7 +9166,7 @@ static void ggml_compute_forward_ssm_scan_f32(
                 // d_state
                 for (int i0 = 0; i0 < nc; ++i0) {
                     const int i = i0 + ii*nc;
-                    const int ig = i0 + (h & (ng - 1))*nc;
+                    const int ig = i0 + g*nc;
                     // state = prev_state * dA + dB * x
                     const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
                     // y = rowwise_dotprod(state, C)
```
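The SSM scan previously required the number of B/C groups `ng` to be a power of two, because the group for head `h` was selected with the bit mask `h & (ng - 1)`; the new code only asserts `nh % ng == 0` and maps each contiguous block of `nh/ng` heads to one group, matching a `repeat_interleave` of the group tensors across heads. The two mappings differ even when both are legal:

```cpp
// Head -> group mapping, old vs. new, for nh = 8 heads and ng = 2 groups.
// Old masking alternates 0,1,0,1,...; the new division gives 0,0,0,0,1,1,1,1.
#include <cstdio>

int main() {
    const int nh = 8, ng = 2;
    for (int h = 0; h < nh; ++h) {
        const int g_old = h & (ng - 1);  // needed ng to be a power of two
        const int g_new = h / (nh / ng); // repeat_interleave semantics
        std::printf("h=%d  old g=%d  new g=%d\n", h, g_old, g_new);
    }
    return 0;
}
```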
```diff
@@ -9080,6 +9420,10 @@ void ggml_compute_forward_glu(
             {
                 ggml_compute_forward_swiglu(params, dst);
             } break;
+        case GGML_GLU_OP_SWIGLU_OAI:
+            {
+                ggml_compute_forward_swiglu_oai(params, dst);
+            } break;
         case GGML_GLU_OP_GEGLU_ERF:
             {
                 ggml_compute_forward_geglu_erf(params, dst);
@@ -9683,8 +10027,8 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
     int64_t h_stride_2d = head_size * head_size;
 
 #if defined(GGML_SIMD)
-    #if defined(__ARM_FEATURE_SVE)
-        // scalar Route to scalar implementation //TODO: Write SVE code
+    #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
+        // scalar Route to scalar implementation //TODO: Write SVE code and RVV code
         for (int64_t t = 0; t < T; t++) {
             int64_t t_offset = t * t_stride;
             int64_t state_offset = head_size * C * (t / (T / n_seqs));
```
```diff
@@ -10132,6 +10476,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+
     const float alpha  = adamw_params_ptr[0];
     const float beta1  = adamw_params_ptr[1];
     const float beta2  = adamw_params_ptr[2];
@@ -10139,7 +10484,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const float wd     = adamw_params_ptr[4];
     const float beta1h = adamw_params_ptr[5];
     const float beta2h = adamw_params_ptr[6];
-
+    const float keep   = 1.f - alpha * wd;
     for (int ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
         const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -10162,7 +10507,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
             // The weight decay is applied independently of the Adam momenta m and v.
             // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
             // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+            w[i00] = w[i00] * keep - alpha * mh / vh;
         }
     }
 }
```
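The AdamW rewrite hoists the loop-invariant decoupled weight-decay multiplier `keep = 1 - alpha * wd` out of the element loop; `mh` and `vh` are the bias-corrected Adam moment terms computed earlier in the loop body. Schematically, with learning rate α and decay λ (here `wd`), this is the decoupled form of Loshchilov & Hutter (the paper already cited in the code comment):

```latex
% Decoupled AdamW step, as applied element-wise in the hunk above:
w_{t+1} = (1 - \alpha\lambda)\,w_t \;-\; \alpha\,\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}
```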
```diff
@@ -10184,3 +10529,63 @@ void ggml_compute_forward_opt_step_adamw(
         }
     }
 }
+
+static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0       = dst->src[0];
+    const ggml_tensor * src0_grad  = dst->src[1];
+    const ggml_tensor * sgd_params = dst->src[2];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(sgd_params) == 2);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1) / nth;
+
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // using adamw param subset we care about - alpha, wd - could have a separate struct
+    const float * sgd_params_ptr = ggml_get_data_f32(sgd_params);
+    const float   alpha          = sgd_params_ptr[0];
+    const float   keep           = 1.f - alpha * sgd_params_ptr[1];
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir / (ne02 * ne01);
+        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
+        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
+
+        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
+
+        float       * w = (float       *) ((char       *) src0->data      + offset); // weight
+        const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            w[i00] = w[i00] * keep - alpha * g[i00];
+        }
+    }
+}
+
+void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_sgd_f32(params, dst);
+            }
+            break;
+        default:
+            {
+                GGML_ABORT("fatal error - sgd is F32 only");
+            }
+    }
+}
```