@novastera-oss/llamarn 0.3.1 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +86 -3
- package/RNLlamaCpp.podspec +1 -1
- package/android/CMakeLists.txt +11 -3
- package/android/generated/jni/react/renderer/components/RNLlamaCppSpec/RNLlamaCppSpecJSI.h +49 -4
- package/android/src/main/cpp/include/llama.h +53 -114
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +2 -10
- package/cpp/PureCppImpl.cpp +71 -4
- package/cpp/SystemUtils.cpp +3 -7
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -1
- package/cpp/llama.cpp/Makefile +6 -1605
- package/cpp/llama.cpp/README.md +5 -1
- package/cpp/llama.cpp/common/arg.cpp +230 -51
- package/cpp/llama.cpp/common/chat-parser.cpp +9 -1
- package/cpp/llama.cpp/common/chat.cpp +539 -8
- package/cpp/llama.cpp/common/chat.h +8 -1
- package/cpp/llama.cpp/common/common.cpp +60 -15
- package/cpp/llama.cpp/common/common.h +64 -15
- package/cpp/llama.cpp/common/speculative.cpp +135 -54
- package/cpp/llama.cpp/common/speculative.h +8 -1
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1216 -109
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +19 -6
- package/cpp/llama.cpp/convert_lora_to_gguf.py +1 -1
- package/cpp/llama.cpp/flake.nix +0 -5
- package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -3
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +71 -70
- package/cpp/llama.cpp/ggml/include/ggml-opt.h +25 -6
- package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +90 -3
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +13 -1
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +10 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +113 -17
- package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +701 -585
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +13 -3
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +52 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +274 -91
- package/cpp/llama.cpp/ggml/src/ggml-common.h +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +90 -569
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +55 -341
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +371 -298
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +33 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +26 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +428 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +458 -46
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
- package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +122 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +9 -11
- package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cu +58 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +275 -170
- package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +103 -65
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +171 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +33 -7
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +2 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +3 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +83 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +116 -57
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +45 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +56 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +61 -39
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +70 -49
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +70 -21
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +162 -50
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +5 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +208 -97
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +46 -35
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +56 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +95 -51
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +427 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +204 -57
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +252 -168
- package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
- package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +10 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +192 -19
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cu +67 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +1 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cu +34 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +16 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +153 -71
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +6 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +21 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +75 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -25
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -1
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +31 -20
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +342 -131
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +464 -134
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1108 -176
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +343 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +343 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +346 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +41 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +97 -41
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +110 -16
- package/cpp/llama.cpp/ggml/src/ggml-quants.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +22 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +0 -212
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +213 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +117 -238
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quantize.hpp +133 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1666 -633
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +107 -43
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +21 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +16 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +44 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +44 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +26 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -17
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +37 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +109 -55
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +71 -41
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -11
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +55 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +75 -20
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +807 -412
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +72 -22
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +1794 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +846 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +204 -50
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +187 -2
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +11 -2
- package/cpp/llama.cpp/gguf-py/gguf/quants.py +53 -4
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +67 -63
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +7 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +120 -16
- package/cpp/llama.cpp/gguf-py/gguf/utility.py +5 -1
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +284 -1
- package/cpp/llama.cpp/gguf-py/tests/test_quants.py +14 -5
- package/cpp/llama.cpp/include/llama.h +53 -114
- package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +171 -0
- package/cpp/llama.cpp/models/templates/README.md +2 -1
- package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +59 -0
- package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +331 -0
- package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +105 -0
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -1
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +0 -6
- package/cpp/llama.cpp/requirements/requirements-pydantic.txt +1 -1
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-adapter.cpp +68 -4
- package/cpp/llama.cpp/src/llama-adapter.h +3 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +192 -2
- package/cpp/llama.cpp/src/llama-arch.h +18 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +2 -2
- package/cpp/llama.cpp/src/llama-chat.cpp +47 -6
- package/cpp/llama.cpp/src/llama-chat.h +3 -0
- package/cpp/llama.cpp/src/llama-context.cpp +61 -252
- package/cpp/llama.cpp/src/llama-context.h +10 -15
- package/cpp/llama.cpp/src/llama-cparams.h +0 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +180 -85
- package/cpp/llama.cpp/src/llama-graph.h +90 -51
- package/cpp/llama.cpp/src/llama-hparams.cpp +34 -3
- package/cpp/llama.cpp/src/llama-hparams.h +21 -6
- package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +79 -56
- package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +30 -28
- package/cpp/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +240 -632
- package/cpp/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +39 -74
- package/cpp/llama.cpp/src/llama-kv-cells.h +21 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +41 -35
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +26 -29
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +13 -9
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +10 -14
- package/cpp/llama.cpp/src/llama-memory.h +13 -10
- package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
- package/cpp/llama.cpp/src/llama-model-loader.h +3 -2
- package/cpp/llama.cpp/src/llama-model.cpp +1959 -419
- package/cpp/llama.cpp/src/llama-model.h +28 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +40 -4
- package/cpp/llama.cpp/src/llama-vocab.cpp +51 -2
- package/cpp/llama.cpp/src/llama-vocab.h +1 -0
- package/cpp/llama.cpp/vendor/minja/chat-template.hpp +16 -7
- package/cpp/llama.cpp/vendor/minja/minja.hpp +47 -12
- package/cpp/rn-completion.cpp +3 -27
- package/ios/generated/RNLlamaCppSpec/RNLlamaCppSpec.h +30 -0
- package/ios/generated/RNLlamaCppSpecJSI.h +49 -4
- package/ios/include/chat.h +8 -1
- package/ios/include/common/minja/chat-template.hpp +16 -7
- package/ios/include/common/minja/minja.hpp +47 -12
- package/ios/include/common.h +64 -15
- package/ios/include/llama.h +53 -114
- package/ios/include/speculative.h +8 -1
- package/ios/libs/llama.xcframework/Info.plist +18 -18
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5557 -5267
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5520 -5238
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4242 -4016
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5556 -5267
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5553 -5303
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5515 -5274
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4238 -4044
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/lib/module/NativeRNLlamaCpp.js.map +1 -1
- package/lib/typescript/src/NativeRNLlamaCpp.d.ts +5 -0
- package/lib/typescript/src/NativeRNLlamaCpp.d.ts.map +1 -1
- package/package.json +1 -2
- package/src/NativeRNLlamaCpp.ts +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -56
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
#version 450
|
|
2
|
+
|
|
3
|
+
#include "generic_binary_head.comp"
|
|
4
|
+
#include "types.comp"
|
|
5
|
+
|
|
6
|
+
#extension GL_EXT_control_flow_attributes : enable
|
|
7
|
+
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
|
8
|
+
#extension GL_KHR_shader_subgroup_basic : enable
|
|
9
|
+
|
|
10
|
+
#define BLOCK_SIZE 128
|
|
11
|
+
|
|
12
|
+
layout (constant_id = 1) const bool do_multiply = false;
|
|
13
|
+
|
|
14
|
+
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
|
15
|
+
|
|
16
|
+
layout (binding = 3, std430) readonly buffer PartialsBuf {float partial_sums[];};
|
|
17
|
+
|
|
18
|
+
shared FLOAT_TYPE sumsh[BLOCK_SIZE];
|
|
19
|
+
|
|
20
|
+
void main() {
|
|
21
|
+
const uint ncols = p.ne00;
|
|
22
|
+
const uint nrows = gl_NumWorkGroups.x;
|
|
23
|
+
const uint nchannels = gl_NumWorkGroups.y;
|
|
24
|
+
|
|
25
|
+
const uint row = 0;
|
|
26
|
+
const uint channel = gl_WorkGroupID.y;
|
|
27
|
+
const uint samp = gl_WorkGroupID.z;
|
|
28
|
+
// The work is split across multiple workgroups in the x dimension. Each invocation
|
|
29
|
+
// processes one element
|
|
30
|
+
const uint tid = gl_GlobalInvocationID.x;
|
|
31
|
+
|
|
32
|
+
const uint stride_row = p.nb01;
|
|
33
|
+
const uint stride_channel = p.nb02;
|
|
34
|
+
const uint stride_sample = p.nb03;
|
|
35
|
+
|
|
36
|
+
uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
|
|
37
|
+
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
|
|
38
|
+
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
|
|
39
|
+
|
|
40
|
+
FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
|
41
|
+
|
|
42
|
+
uint32_t num_partials = p.param3;
|
|
43
|
+
for (uint32_t i = gl_SubgroupInvocationID; i < num_partials; i += gl_SubgroupSize) {
|
|
44
|
+
sum += partial_sums[i];
|
|
45
|
+
}
|
|
46
|
+
sum = subgroupAdd(sum);
|
|
47
|
+
|
|
48
|
+
uint col = tid;
|
|
49
|
+
if (col >= ncols) {
|
|
50
|
+
return;
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols);
|
|
54
|
+
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
|
|
55
|
+
|
|
56
|
+
if (do_multiply) {
|
|
57
|
+
if (ncols > p.ne10) {
|
|
58
|
+
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
|
|
59
|
+
} else {
|
|
60
|
+
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
|
|
61
|
+
}
|
|
62
|
+
} else {
|
|
63
|
+
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
|
|
64
|
+
}
|
|
65
|
+
}
|
|
@@ -20,6 +20,7 @@ layout (push_constant) uniform parameter
|
|
|
20
20
|
float m1;
|
|
21
21
|
uint n_head_log2;
|
|
22
22
|
uint nrows_x;
|
|
23
|
+
uint has_sinks;
|
|
23
24
|
} p;
|
|
24
25
|
|
|
25
26
|
#include "types.comp"
|
|
@@ -29,7 +30,8 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
|
|
29
30
|
|
|
30
31
|
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
|
31
32
|
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
|
|
32
|
-
layout (binding = 2) buffer
|
|
33
|
+
layout (binding = 2) readonly buffer Z {float data_c[];};
|
|
34
|
+
layout (binding = 3) buffer D {D_TYPE data_d[];};
|
|
33
35
|
|
|
34
36
|
shared FLOAT_TYPE vals[BLOCK_SIZE];
|
|
35
37
|
|
|
@@ -60,13 +62,13 @@ void soft_max(uint num_iters) {
|
|
|
60
62
|
const uint h = (rowx / p.ne01) % p.ne02; // head index
|
|
61
63
|
|
|
62
64
|
const float base = h < p.n_head_log2 ? p.m0 : p.m1;
|
|
63
|
-
const uint exp
|
|
65
|
+
const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
|
|
64
66
|
|
|
65
67
|
slope = pow(base, exp);
|
|
66
68
|
}
|
|
67
69
|
|
|
68
70
|
// Find max
|
|
69
|
-
FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000);
|
|
71
|
+
FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
|
|
70
72
|
|
|
71
73
|
// Cache values while we compute the max, so we don't need to read them
|
|
72
74
|
// again when we're ready to compute exp(x-max).
|
|
@@ -148,6 +150,10 @@ void soft_max(uint num_iters) {
|
|
|
148
150
|
}
|
|
149
151
|
sum = vals[0];
|
|
150
152
|
|
|
153
|
+
if (p.has_sinks != 0) {
|
|
154
|
+
sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val));
|
|
155
|
+
}
|
|
156
|
+
|
|
151
157
|
FLOAT_TYPE rcpdivisor = 1.0/sum;
|
|
152
158
|
|
|
153
159
|
[[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
#version 450
|
|
2
|
+
|
|
3
|
+
#include "types.comp"
|
|
4
|
+
#include "generic_unary_head.comp"
|
|
5
|
+
|
|
6
|
+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
|
7
|
+
|
|
8
|
+
void main() {
|
|
9
|
+
const uint idx = get_idx();
|
|
10
|
+
|
|
11
|
+
if (idx >= p.ne) {
|
|
12
|
+
return;
|
|
13
|
+
}
|
|
14
|
+
|
|
15
|
+
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
|
16
|
+
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sqrt(val));
|
|
17
|
+
}
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
#version 450
|
|
2
2
|
|
|
3
|
-
#include "generic_head.comp"
|
|
4
3
|
#include "types.comp"
|
|
5
4
|
|
|
6
5
|
#extension GL_EXT_control_flow_attributes : enable
|
|
6
|
+
|
|
7
7
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
|
8
8
|
|
|
9
9
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
|
@@ -11,16 +11,49 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
|
|
11
11
|
|
|
12
12
|
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
|
13
13
|
|
|
14
|
+
layout (push_constant) uniform parameter
|
|
15
|
+
{
|
|
16
|
+
uint n_cols;
|
|
17
|
+
uint ne01, ne02;
|
|
18
|
+
uint nb01, nb02, nb03;
|
|
19
|
+
uint nb11, nb12, nb13;
|
|
20
|
+
float weight;
|
|
21
|
+
uint misalign_offsets;
|
|
22
|
+
uint ne0_12mp, ne0_12L;
|
|
23
|
+
uint ne0_1mp, ne0_1L;
|
|
24
|
+
} p;
|
|
25
|
+
|
|
26
|
+
uint get_aoffset() { return p.misalign_offsets >> 16; }
|
|
27
|
+
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
|
|
28
|
+
|
|
29
|
+
// see init_fastdiv_values in ggml-vulkan.cpp
|
|
30
|
+
uint fastdiv(uint n, uint mp, uint L) {
|
|
31
|
+
uint msbs, lsbs;
|
|
32
|
+
// msbs = mulhi(n, mp)
|
|
33
|
+
umulExtended(n, mp, msbs, lsbs);
|
|
34
|
+
return (msbs + n) >> L;
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
|
|
14
38
|
shared FLOAT_TYPE tmp[BLOCK_SIZE];
|
|
15
39
|
|
|
16
40
|
void main() {
|
|
17
41
|
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
|
18
42
|
const uint col = gl_LocalInvocationID.x;
|
|
43
|
+
const float weight = p.weight;
|
|
44
|
+
|
|
45
|
+
const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
|
|
46
|
+
const uint i03_offset = i03 * p.ne01*p.ne02;
|
|
47
|
+
const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
|
|
48
|
+
const uint i01 = row - i03_offset - i02*p.ne01;
|
|
49
|
+
|
|
50
|
+
const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
|
|
51
|
+
const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
|
|
19
52
|
|
|
20
|
-
tmp[col] = FLOAT_TYPE(0.
|
|
53
|
+
tmp[col] = FLOAT_TYPE(0.0);
|
|
21
54
|
|
|
22
|
-
for (uint i = col; i < p.
|
|
23
|
-
tmp[col] += FLOAT_TYPE(data_a[
|
|
55
|
+
for (uint i = col; i < p.n_cols; i += BLOCK_SIZE) {
|
|
56
|
+
tmp[col] += FLOAT_TYPE(data_a[src_idx + i]);
|
|
24
57
|
}
|
|
25
58
|
|
|
26
59
|
barrier();
|
|
@@ -32,6 +65,6 @@ void main() {
|
|
|
32
65
|
}
|
|
33
66
|
|
|
34
67
|
if (col == 0) {
|
|
35
|
-
data_d[
|
|
68
|
+
data_d[dst_idx] = D_TYPE(tmp[0] * weight);
|
|
36
69
|
}
|
|
37
70
|
}
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
#version 450
|
|
2
|
+
|
|
3
|
+
#include "glu_head.comp"
|
|
4
|
+
|
|
5
|
+
float op(float a, float b) {
|
|
6
|
+
float xi = min(a, p.limit);
|
|
7
|
+
float gi = max(min(b, p.limit), -p.limit);
|
|
8
|
+
|
|
9
|
+
float out_glu = xi / (1.0f + exp(-xi * p.alpha));
|
|
10
|
+
out_glu = out_glu * (1.0f + gi);
|
|
11
|
+
return out_glu;
|
|
12
|
+
}
|
|
13
|
+
|
|
14
|
+
#include "glu_main.comp"
|
|
@@ -1337,6 +1337,29 @@ struct block_iq4_nl_packed16
|
|
|
1337
1337
|
#define A_TYPE_PACKED16 block_iq4_nl_packed16
|
|
1338
1338
|
#endif
|
|
1339
1339
|
|
|
1340
|
+
#define QUANT_K_MXFP4 32
|
|
1341
|
+
#define QUANT_R_MXFP4 2
|
|
1342
|
+
|
|
1343
|
+
struct block_mxfp4
|
|
1344
|
+
{
|
|
1345
|
+
uint8_t e;
|
|
1346
|
+
uint8_t qs[QUANT_K_MXFP4/2];
|
|
1347
|
+
};
|
|
1348
|
+
|
|
1349
|
+
//struct block_mxfp4_packed16
|
|
1350
|
+
//{
|
|
1351
|
+
// uint8_t e;
|
|
1352
|
+
// uint16_t qs[QUANT_K_MXFP4/2/2];
|
|
1353
|
+
//};
|
|
1354
|
+
|
|
1355
|
+
#if defined(DATA_A_MXFP4)
|
|
1356
|
+
#define QUANT_K QUANT_K_MXFP4
|
|
1357
|
+
#define QUANT_R QUANT_R_MXFP4
|
|
1358
|
+
#define QUANT_AUXF 1
|
|
1359
|
+
#define A_TYPE block_mxfp4
|
|
1360
|
+
//#define A_TYPE_PACKED16 block_mxfp4_packed16
|
|
1361
|
+
#endif
|
|
1362
|
+
|
|
1340
1363
|
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
|
|
1341
1364
|
const int8_t kvalues_iq4nl_const[16] = {
|
|
1342
1365
|
int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
|
|
@@ -1356,6 +1379,25 @@ void init_iq_shmem(uvec3 wgsize)
|
|
|
1356
1379
|
}
|
|
1357
1380
|
#endif
|
|
1358
1381
|
|
|
1382
|
+
#if defined(DATA_A_MXFP4)
|
|
1383
|
+
const FLOAT_TYPE kvalues_mxfp4_const[16] = {
|
|
1384
|
+
FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f),
|
|
1385
|
+
FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f)
|
|
1386
|
+
};
|
|
1387
|
+
|
|
1388
|
+
shared FLOAT_TYPE kvalues_mxfp4[16];
|
|
1389
|
+
|
|
1390
|
+
#define NEEDS_INIT_IQ_SHMEM
|
|
1391
|
+
void init_iq_shmem(uvec3 wgsize)
|
|
1392
|
+
{
|
|
1393
|
+
// copy the table into shared memory and sync
|
|
1394
|
+
for (uint i = gl_LocalInvocationIndex.x; i < kvalues_mxfp4.length(); i += wgsize.x) {
|
|
1395
|
+
kvalues_mxfp4[i] = kvalues_mxfp4_const[i];
|
|
1396
|
+
}
|
|
1397
|
+
barrier();
|
|
1398
|
+
}
|
|
1399
|
+
#endif
|
|
1400
|
+
|
|
1359
1401
|
// returns the bfloat value in the low 16b.
|
|
1360
1402
|
// See ggml_compute_fp32_to_bf16
|
|
1361
1403
|
uint32_t fp32_to_bf16(float f)
|
|
@@ -1370,4 +1412,17 @@ float bf16_to_fp32(uint32_t u)
|
|
|
1370
1412
|
return uintBitsToFloat(u << 16);
|
|
1371
1413
|
}
|
|
1372
1414
|
|
|
1415
|
+
float e8m0_to_fp32(uint8_t x) {
|
|
1416
|
+
uint32_t bits;
|
|
1417
|
+
|
|
1418
|
+
if (x == 0) {
|
|
1419
|
+
bits = 0x00400000;
|
|
1420
|
+
} else {
|
|
1421
|
+
bits = x;
|
|
1422
|
+
bits = bits << 23;
|
|
1423
|
+
}
|
|
1424
|
+
|
|
1425
|
+
return uintBitsToFloat(bits);
|
|
1426
|
+
}
|
|
1427
|
+
|
|
1373
1428
|
#endif // !defined(GGML_TYPES_COMP)
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
#ifndef UTILS_COMP
|
|
2
|
+
#define UTILS_COMP
|
|
3
|
+
|
|
4
|
+
// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1
|
|
5
|
+
uint fastmod(uint a, uint b) {
|
|
6
|
+
if ((b & (b-1)) == 0) {
|
|
7
|
+
return a & (b-1);
|
|
8
|
+
}
|
|
9
|
+
return a % b;
|
|
10
|
+
}
|
|
11
|
+
|
|
12
|
+
uint fastdiv(uint a, uint b) {
|
|
13
|
+
return (a < b) ? 0 : (a / b);
|
|
14
|
+
}
|
|
15
|
+
|
|
16
|
+
void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03, uint ne00, uint ne01, uint ne02, uint ne03) {
|
|
17
|
+
i03 = fastdiv(idx, (ne02*ne01*ne00));
|
|
18
|
+
const uint i03_offset = i03 * ne02*ne01*ne00;
|
|
19
|
+
i02 = fastdiv((idx - i03_offset), (ne01*ne00));
|
|
20
|
+
const uint i02_offset = i02*ne01*ne00;
|
|
21
|
+
i01 = (idx - i03_offset - i02_offset) / ne00;
|
|
22
|
+
i00 = idx - i03_offset - i02_offset - i01*ne00;
|
|
23
|
+
}
|
|
24
|
+
|
|
25
|
+
#endif // UTILS_COMP
|
|
@@ -64,9 +64,16 @@ const std::vector<std::string> type_names = {
|
|
|
64
64
|
"iq3_s",
|
|
65
65
|
"iq4_xs",
|
|
66
66
|
"iq4_nl",
|
|
67
|
+
"mxfp4",
|
|
67
68
|
"bf16",
|
|
68
69
|
};
|
|
69
70
|
|
|
71
|
+
enum MatMulIdType {
|
|
72
|
+
NONE,
|
|
73
|
+
DEFAULT,
|
|
74
|
+
SUBGROUP,
|
|
75
|
+
};
|
|
76
|
+
|
|
70
77
|
namespace {
|
|
71
78
|
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
|
|
72
79
|
#ifdef _WIN32
|
|
@@ -118,7 +125,7 @@ void execute_command(const std::string& command, std::string& stdout_str, std::s
|
|
|
118
125
|
CloseHandle(pi.hProcess);
|
|
119
126
|
CloseHandle(pi.hThread);
|
|
120
127
|
#else
|
|
121
|
-
int stdout_pipe[2];
|
|
128
|
+
int stdout_pipe[2];
|
|
122
129
|
int stderr_pipe[2];
|
|
123
130
|
|
|
124
131
|
if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) {
|
|
@@ -222,7 +229,8 @@ void string_to_spv_func(const std::string& _name, const std::string& in_fname, c
|
|
|
222
229
|
std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2";
|
|
223
230
|
|
|
224
231
|
// disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734
|
|
225
|
-
|
|
232
|
+
// disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344
|
|
233
|
+
std::string opt_level = (coopmat || name.find("bf16") != std::string::npos) ? "" : "-O";
|
|
226
234
|
|
|
227
235
|
#ifdef _WIN32
|
|
228
236
|
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
|
|
@@ -291,7 +299,7 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
|
|
|
291
299
|
compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
|
|
292
300
|
}
|
|
293
301
|
|
|
294
|
-
void matmul_shaders(bool fp16,
|
|
302
|
+
void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) {
|
|
295
303
|
std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
|
|
296
304
|
std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
|
|
297
305
|
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
|
|
@@ -301,9 +309,13 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
|
|
301
309
|
};
|
|
302
310
|
std::string shader_name = "matmul";
|
|
303
311
|
|
|
304
|
-
if (
|
|
312
|
+
if (matmul_id_type == MatMulIdType::DEFAULT) {
|
|
305
313
|
base_dict["MUL_MAT_ID"] = "1";
|
|
306
314
|
shader_name = "matmul_id";
|
|
315
|
+
} else if (matmul_id_type == MatMulIdType::SUBGROUP) {
|
|
316
|
+
base_dict["MUL_MAT_ID"] = "1";
|
|
317
|
+
base_dict["MUL_MAT_ID_USE_SUBGROUPS"] = "1";
|
|
318
|
+
shader_name = "matmul_id_subgroup";
|
|
307
319
|
}
|
|
308
320
|
|
|
309
321
|
if (fp16) {
|
|
@@ -362,7 +374,7 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
|
|
362
374
|
std::string load_vec_quant = "2";
|
|
363
375
|
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
|
|
364
376
|
load_vec_quant = "8";
|
|
365
|
-
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl"))
|
|
377
|
+
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
|
|
366
378
|
load_vec_quant = "4";
|
|
367
379
|
|
|
368
380
|
if (tname == "bf16") {
|
|
@@ -387,7 +399,7 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
|
|
387
399
|
}
|
|
388
400
|
|
|
389
401
|
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
390
|
-
if (!coopmat && !coopmat2 &&
|
|
402
|
+
if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
|
|
391
403
|
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
|
|
392
404
|
}
|
|
393
405
|
#endif
|
|
@@ -399,26 +411,28 @@ void process_shaders() {
|
|
|
399
411
|
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
|
|
400
412
|
|
|
401
413
|
// matmul
|
|
402
|
-
for (const
|
|
414
|
+
for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
|
|
403
415
|
// No coopmats
|
|
404
416
|
// fp32
|
|
405
|
-
matmul_shaders(false,
|
|
417
|
+
matmul_shaders(false, matmul_id_type, false, false, false);
|
|
406
418
|
|
|
407
419
|
// fp16, fp32acc and fp16acc
|
|
408
|
-
matmul_shaders(true,
|
|
409
|
-
matmul_shaders(true,
|
|
420
|
+
matmul_shaders(true, matmul_id_type, false, false, false);
|
|
421
|
+
matmul_shaders(true, matmul_id_type, false, false, true);
|
|
410
422
|
|
|
423
|
+
if (matmul_id_type != MatMulIdType::DEFAULT) {
|
|
411
424
|
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
425
|
+
// Coopmat, fp32acc and fp16acc
|
|
426
|
+
matmul_shaders(true, matmul_id_type, true, false, false);
|
|
427
|
+
matmul_shaders(true, matmul_id_type, true, false, true);
|
|
415
428
|
#endif
|
|
416
429
|
|
|
417
430
|
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
431
|
+
// Coopmat2, fp32acc and fp16acc
|
|
432
|
+
matmul_shaders(true, matmul_id_type, false, true, false);
|
|
433
|
+
matmul_shaders(true, matmul_id_type, false, true, true);
|
|
421
434
|
#endif
|
|
435
|
+
}
|
|
422
436
|
}
|
|
423
437
|
|
|
424
438
|
// flash attention
|
|
@@ -471,6 +485,9 @@ void process_shaders() {
|
|
|
471
485
|
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
|
|
472
486
|
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
|
|
473
487
|
|
|
488
|
+
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
|
489
|
+
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
|
490
|
+
|
|
474
491
|
string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
|
|
475
492
|
|
|
476
493
|
// Dequant shaders
|
|
@@ -498,6 +515,7 @@ void process_shaders() {
|
|
|
498
515
|
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
|
499
516
|
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
|
500
517
|
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
|
518
|
+
string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
|
501
519
|
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
|
502
520
|
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
|
503
521
|
|
|
@@ -533,13 +551,15 @@ void process_shaders() {
|
|
|
533
551
|
s += std::string(dst_f16 ? "_f16" : "_f32");
|
|
534
552
|
return s;
|
|
535
553
|
};
|
|
536
|
-
for (std::string op : {"add", "sub", "mul", "div"}) {
|
|
554
|
+
for (std::string op : {"add", "sub", "mul", "div", "add_rms", }) {
|
|
537
555
|
for (auto src0_f16 : {false, true}) {
|
|
538
556
|
for (auto src1_f16 : {false, true}) {
|
|
539
557
|
for (auto dst_f16 : {false, true}) {
|
|
540
558
|
for (auto rte : {false, true}) {
|
|
559
|
+
auto source = op == "add_rms" ? std::string("add") : op;
|
|
541
560
|
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
|
|
542
|
-
|
|
561
|
+
auto add_rms = op == "add_rms" ? "1" : "0";
|
|
562
|
+
string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS" , add_rms}});
|
|
543
563
|
}
|
|
544
564
|
}
|
|
545
565
|
}
|
|
@@ -565,6 +585,8 @@ void process_shaders() {
|
|
|
565
585
|
|
|
566
586
|
string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
|
567
587
|
|
|
588
|
+
string_to_spv("sqrt_f32", "sqrt.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
|
589
|
+
|
|
568
590
|
string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
|
569
591
|
|
|
570
592
|
string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
|
@@ -579,6 +601,8 @@ void process_shaders() {
|
|
|
579
601
|
|
|
580
602
|
string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
|
581
603
|
|
|
604
|
+
string_to_spv("exp_f16", "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
|
605
|
+
string_to_spv("exp_f32", "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
|
582
606
|
string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
|
583
607
|
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
|
584
608
|
string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
|
@@ -602,6 +626,8 @@ void process_shaders() {
|
|
|
602
626
|
string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
|
603
627
|
string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
|
604
628
|
string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
|
629
|
+
string_to_spv("swiglu_oai_f16" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
|
630
|
+
string_to_spv("swiglu_oai_f32" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
|
605
631
|
string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
|
606
632
|
string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
|
607
633
|
string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
|
@@ -654,14 +680,31 @@ void process_shaders() {
|
|
|
654
680
|
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
|
655
681
|
|
|
656
682
|
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
|
683
|
+
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
|
684
|
+
|
|
685
|
+
string_to_spv("conv2d_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
|
|
686
|
+
string_to_spv("conv2d_f16_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
|
|
687
|
+
|
|
688
|
+
string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}});
|
|
689
|
+
string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}});
|
|
657
690
|
|
|
658
|
-
|
|
691
|
+
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
|
692
|
+
string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true);
|
|
693
|
+
string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true);
|
|
694
|
+
#endif
|
|
659
695
|
|
|
660
696
|
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
|
|
661
697
|
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
|
|
698
|
+
string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
|
|
699
|
+
string_to_spv("conv2d_dw_cwhn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
|
|
662
700
|
|
|
663
701
|
string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
|
664
702
|
|
|
703
|
+
string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
|
704
|
+
|
|
705
|
+
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
|
|
706
|
+
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
|
|
707
|
+
|
|
665
708
|
for (auto &c : compiles) {
|
|
666
709
|
c.wait();
|
|
667
710
|
}
|
|
@@ -718,7 +761,7 @@ void write_output_files() {
|
|
|
718
761
|
}
|
|
719
762
|
|
|
720
763
|
std::string suffixes[2] = {"_f32", "_f16"};
|
|
721
|
-
for (const char *op : {"add", "sub", "mul", "div"}) {
|
|
764
|
+
for (const char *op : {"add", "sub", "mul", "div", "add_rms"}) {
|
|
722
765
|
fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op);
|
|
723
766
|
fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op);
|
|
724
767
|
std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = ";
|
|
@@ -770,6 +813,18 @@ void write_output_files() {
|
|
|
770
813
|
fputs(data.c_str(), src);
|
|
771
814
|
fputs(len.c_str(), src);
|
|
772
815
|
}
|
|
816
|
+
|
|
817
|
+
for (const std::string& btype : {"f16", "f32"}) {
|
|
818
|
+
for (const auto& tname : type_names) {
|
|
819
|
+
fprintf(hdr, "extern unsigned char *arr_dmmv_%s_%s_f32_data[2];\n", tname.c_str(), btype.c_str());
|
|
820
|
+
fprintf(hdr, "extern uint64_t arr_dmmv_%s_%s_f32_len[2];\n", tname.c_str(), btype.c_str());
|
|
821
|
+
std::string data = "unsigned char *arr_dmmv_" + tname + "_" + btype + "_f32_data[2] = {mul_mat_vec_" + tname + "_" + btype + "_f32_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_data};\n";
|
|
822
|
+
std::string len = "uint64_t arr_dmmv_" + tname + "_" + btype + "_f32_len[2] = {mul_mat_vec_" + tname + "_" + btype + "_f32_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_len};\n";
|
|
823
|
+
fputs(data.c_str(), src);
|
|
824
|
+
fputs(len.c_str(), src);
|
|
825
|
+
}
|
|
826
|
+
}
|
|
827
|
+
|
|
773
828
|
fclose(hdr);
|
|
774
829
|
fclose(src);
|
|
775
830
|
}
|
|
@@ -20,8 +20,8 @@ add_custom_command(
|
|
|
20
20
|
COMMAND ${CMAKE_COMMAND} -E make_directory ${SHADER_OUTPUT_DIR}
|
|
21
21
|
COMMAND ${CMAKE_COMMAND} -E env PYTHONIOENCODING=utf-8
|
|
22
22
|
${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
|
|
23
|
-
--
|
|
24
|
-
--
|
|
23
|
+
--input_dir "${SHADER_DIR}"
|
|
24
|
+
--output_file "${SHADER_HEADER}"
|
|
25
25
|
DEPENDS ${WGSL_SHADER_FILES} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
|
|
26
26
|
VERBATIM
|
|
27
27
|
)
|