@novastera-oss/llamarn 0.3.1 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +86 -3
- package/RNLlamaCpp.podspec +1 -1
- package/android/CMakeLists.txt +11 -3
- package/android/generated/jni/react/renderer/components/RNLlamaCppSpec/RNLlamaCppSpecJSI.h +49 -4
- package/android/src/main/cpp/include/llama.h +53 -114
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +2 -10
- package/cpp/PureCppImpl.cpp +71 -4
- package/cpp/SystemUtils.cpp +3 -7
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -1
- package/cpp/llama.cpp/Makefile +6 -1605
- package/cpp/llama.cpp/README.md +5 -1
- package/cpp/llama.cpp/common/arg.cpp +230 -51
- package/cpp/llama.cpp/common/chat-parser.cpp +9 -1
- package/cpp/llama.cpp/common/chat.cpp +539 -8
- package/cpp/llama.cpp/common/chat.h +8 -1
- package/cpp/llama.cpp/common/common.cpp +60 -15
- package/cpp/llama.cpp/common/common.h +64 -15
- package/cpp/llama.cpp/common/speculative.cpp +135 -54
- package/cpp/llama.cpp/common/speculative.h +8 -1
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1216 -109
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +19 -6
- package/cpp/llama.cpp/convert_lora_to_gguf.py +1 -1
- package/cpp/llama.cpp/flake.nix +0 -5
- package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -3
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +71 -70
- package/cpp/llama.cpp/ggml/include/ggml-opt.h +25 -6
- package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +90 -3
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +13 -1
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +10 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +113 -17
- package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +701 -585
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +13 -3
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +52 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +274 -91
- package/cpp/llama.cpp/ggml/src/ggml-common.h +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +90 -569
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +55 -341
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +371 -298
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +33 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +26 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +428 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +458 -46
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
- package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +122 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +9 -11
- package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cu +58 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +275 -170
- package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +103 -65
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +171 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +33 -7
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +2 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +3 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +83 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +116 -57
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +45 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +56 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +61 -39
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +70 -49
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +70 -21
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +162 -50
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +5 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +208 -97
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +46 -35
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +56 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +95 -51
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +427 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +204 -57
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +252 -168
- package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
- package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +10 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +192 -19
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cu +67 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +1 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cu +34 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +16 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +153 -71
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +6 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +21 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +75 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -25
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -1
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +31 -20
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +342 -131
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +464 -134
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1108 -176
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +343 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +343 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +346 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +41 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +97 -41
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +110 -16
- package/cpp/llama.cpp/ggml/src/ggml-quants.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +22 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +0 -212
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +213 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +117 -238
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quantize.hpp +133 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1666 -633
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +107 -43
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +21 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +16 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +44 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +44 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +26 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -17
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +37 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +109 -55
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +71 -41
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -11
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +55 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +75 -20
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +807 -412
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +72 -22
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +1794 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +846 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +204 -50
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +187 -2
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +11 -2
- package/cpp/llama.cpp/gguf-py/gguf/quants.py +53 -4
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +67 -63
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +7 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +120 -16
- package/cpp/llama.cpp/gguf-py/gguf/utility.py +5 -1
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +284 -1
- package/cpp/llama.cpp/gguf-py/tests/test_quants.py +14 -5
- package/cpp/llama.cpp/include/llama.h +53 -114
- package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +171 -0
- package/cpp/llama.cpp/models/templates/README.md +2 -1
- package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +59 -0
- package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +331 -0
- package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +105 -0
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -1
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +0 -6
- package/cpp/llama.cpp/requirements/requirements-pydantic.txt +1 -1
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-adapter.cpp +68 -4
- package/cpp/llama.cpp/src/llama-adapter.h +3 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +192 -2
- package/cpp/llama.cpp/src/llama-arch.h +18 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +2 -2
- package/cpp/llama.cpp/src/llama-chat.cpp +47 -6
- package/cpp/llama.cpp/src/llama-chat.h +3 -0
- package/cpp/llama.cpp/src/llama-context.cpp +61 -252
- package/cpp/llama.cpp/src/llama-context.h +10 -15
- package/cpp/llama.cpp/src/llama-cparams.h +0 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +180 -85
- package/cpp/llama.cpp/src/llama-graph.h +90 -51
- package/cpp/llama.cpp/src/llama-hparams.cpp +34 -3
- package/cpp/llama.cpp/src/llama-hparams.h +21 -6
- package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +79 -56
- package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +30 -28
- package/cpp/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +240 -632
- package/cpp/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +39 -74
- package/cpp/llama.cpp/src/llama-kv-cells.h +21 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +41 -35
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +26 -29
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +13 -9
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +10 -14
- package/cpp/llama.cpp/src/llama-memory.h +13 -10
- package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
- package/cpp/llama.cpp/src/llama-model-loader.h +3 -2
- package/cpp/llama.cpp/src/llama-model.cpp +1959 -419
- package/cpp/llama.cpp/src/llama-model.h +28 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +40 -4
- package/cpp/llama.cpp/src/llama-vocab.cpp +51 -2
- package/cpp/llama.cpp/src/llama-vocab.h +1 -0
- package/cpp/llama.cpp/vendor/minja/chat-template.hpp +16 -7
- package/cpp/llama.cpp/vendor/minja/minja.hpp +47 -12
- package/cpp/rn-completion.cpp +3 -27
- package/ios/generated/RNLlamaCppSpec/RNLlamaCppSpec.h +30 -0
- package/ios/generated/RNLlamaCppSpecJSI.h +49 -4
- package/ios/include/chat.h +8 -1
- package/ios/include/common/minja/chat-template.hpp +16 -7
- package/ios/include/common/minja/minja.hpp +47 -12
- package/ios/include/common.h +64 -15
- package/ios/include/llama.h +53 -114
- package/ios/include/speculative.h +8 -1
- package/ios/libs/llama.xcframework/Info.plist +18 -18
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5557 -5267
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5520 -5238
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4242 -4016
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5556 -5267
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5553 -5303
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5515 -5274
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4238 -4044
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/lib/module/NativeRNLlamaCpp.js.map +1 -1
- package/lib/typescript/src/NativeRNLlamaCpp.d.ts +5 -0
- package/lib/typescript/src/NativeRNLlamaCpp.d.ts.map +1 -1
- package/package.json +1 -2
- package/src/NativeRNLlamaCpp.ts +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -56
package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp

@@ -44,6 +44,7 @@
 #include "ggml-sycl/set_rows.hpp"
 #include "ggml-sycl/sycl_hw.hpp"
 #include "ggml-sycl/getrows.hpp"
+#include "ggml-sycl/quantize.hpp"
 #include "ggml.h"
 
 static bool g_sycl_loaded = false;
@@ -1373,120 +1374,6 @@ typedef void (*ggml_sycl_op_mul_mat_t)(
 
 
 
-template<int QUANT_BLOCK_TILE>
-static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
-                          const sycl::nd_item<3> &item_ct1) {
-    const int ix = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2)) * QUANT_BLOCK_TILE;
-
-    if (ix >= kx_padded) {
-        return;
-    }
-
-    const int iy = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                   item_ct1.get_local_id(1);
-
-    const int i_padded = iy*kx_padded + ix;
-
-    block_q8_1 * y = (block_q8_1 *) vy;
-
-    const int ib = i_padded / QK8_1; // block index
-    const int iqs = i_padded % QK8_1; // quant index
-    typedef sycl::vec<float, QUANT_BLOCK_TILE> TC;
-    typedef sycl::vec<int8_t, QUANT_BLOCK_TILE> TQ;
-    TC zeros;
-    TQ qzeros;
-#pragma unroll
-    for (int i = 0; i < QUANT_BLOCK_TILE; i++)
-    {
-        zeros[i] = 0.f;
-        qzeros[i] = 0;
-    }
-    const TC xi = ix < kx ? *(const TC *)&x[iy * kx + ix] : zeros;
-    float sum = xi[0];
-    float amax = sycl::fabs(xi[0]);
-#pragma unroll
-    for (int i = 1; i < QUANT_BLOCK_TILE; i++)
-    {
-        sum += xi[i];
-        amax = sycl::fmax(sycl::fabs(xi[i]), amax);
-    }
-    sum = warp_reduce_sum(sum, item_ct1);
-    amax = warp_reduce_max(amax, item_ct1);
-
-    const float d = amax / 127;
-    TQ q = qzeros;
-    if (amax != 0.0f)
-    {
-#pragma unroll
-        for (int i = 0; i < QUANT_BLOCK_TILE; i++) {
-            q[i] = sycl::round(xi[i] / d);
-        }
-    }
-
-    *(TQ *)&y[ib].qs[iqs] = q;
-
-    if (iqs > 0) {
-        return;
-    }
-
-    reinterpret_cast<sycl::half &>(y[ib].ds.x()) = d;
-    reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
-}
-
-template <int ElementsPerWI>
-static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor,
-                                                      const int kx, const int kx_padded, const sycl::nd_item<1> & it) {
-    /*
-        Quantizes and reorders the resultant q8 tensor in a per row fashion
-        Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
-    */
-
-    auto subgroup_id = it.get_group(0);
-    auto wi_id = it.get_local_id(0);
-
-    const int num_blocks_per_row = kx / QK8_1;
-    auto row = subgroup_id / num_blocks_per_row;
-    auto col = subgroup_id % num_blocks_per_row;
-
-    auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1);
-    auto col_offset = QK8_1 * col + wi_id * ElementsPerWI;
-
-    auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
-    auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2));
-
-    sycl::vec<float, ElementsPerWI> wi_f32_vals;
-    sycl::vec<int8_t, ElementsPerWI> quantized_values;
-
-    auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
-    wi_f32_vals = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset);
-
-    float sum = 0.0f;
-    float amax = 0.0f;
-
-#pragma unroll(ElementsPerWI)
-    for (int i = 0; i < ElementsPerWI; i++) {
-        sum += wi_f32_vals[i];
-        amax = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i]));
-        quantized_values[i] = 0;
-    }
-    sum = sycl::reduce_over_group(it.get_group(), sum, sycl::plus<float>());
-    amax = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum<float>());
-    float d = amax == 0 ? 1 : amax / 127;
-
-#pragma unroll(ElementsPerWI)
-    for (int i = 0; i < ElementsPerWI; i++) {
-        quantized_values[i] = sycl::round(wi_f32_vals[i] / d);
-    }
-
-    d = amax == 0 ? 0 : d;
-
-    *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values;
-    if (wi_id == 0) {
-        *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum));
-    }
-}
-
 static void mul_mat_p021_f16_f32(
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
@@ -1770,32 +1657,6 @@ static void pool2d_nchw_kernel(
     o_ptr[cur_oh * ow + cur_ow] = res;
 }
 
-static void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded,
-                                   bool reorder_q8_tensor, queue_ptr stream) {
-    if (reorder_q8_tensor) {
-        auto local_range = std::size_t(WARP_SIZE);
-        auto num_quant_blocks = ky * (kx / QK8_1);
-        auto global_range = num_quant_blocks * local_range;
-        stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }),
-                             [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                 quantize_and_reorder_q8_1<QK8_1 / WARP_SIZE>(x, vy, kx, kx_padded, it);
-                             });
-    } else {
-        const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
-        const sycl::range<3> num_blocks(1, ky, block_num_x);
-        int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
-        static_assert(QK8_1 % WARP_SIZE == 0);
-        const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
-        {
-            dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
-
-            stream->parallel_for(sycl::nd_range<3>(num_blocks * block_size, block_size),
-                                 [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                     quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
-                                 });
-        }
-    }
-}
 
 static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
                                             float *dst, const int ncols_x,
@@ -2372,10 +2233,10 @@ static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
     peer_access_enabled = enable_peer_access;
 }
 
+template <template <int> typename quantize_f>
 static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                  const ggml_tensor *src1, ggml_tensor *dst,
-                                 ggml_sycl_op_mul_mat_t op
-                                 const bool convert_src1_to_q8_1) try {
+                                 ggml_sycl_op_mul_mat_t op) try {
 
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
 
@@ -2470,6 +2331,8 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
         }
     }
 
+    constexpr bool quantize_enabled = !std::is_same_v<quantize_f<QK8_1 / WARP_SIZE>,
+                                                      no_quantize_q8_1<QK8_1 / WARP_SIZE>>;
     for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
         if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
             continue;
@@ -2495,20 +2358,19 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
             dev[i].src1_ddf = dev[i].src1_ddf_alloc.alloc(ctx.pool(i), ggml_nelements(src1));
         }
 
-        if (
+        if constexpr(quantize_enabled) {
             dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
 
             if (src1_on_device && src1_is_contiguous) {
-                bool reorder_q8_tensor = src0->extra && ((ggml_tensor_extra_gpu *)src0->extra)->optimized_feature.reorder;
                 scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
                                                      /*num_src=*/2, " : converting src1 to Q8_1");
-
-
-
-
-
-
-
+                try {
+                    quantize_row_q8_1_sycl<quantize_f>(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
+                } catch (sycl::exception const &exc) {
+                    std::cerr << "Quantize_row_q8_1_sycl error" << exc.what() << "Exception caught at file:" << __FILE__
+                              << ", line:" << __LINE__ << std::endl;
+                    std::exit(1);
+                }
             }
         }
 
@@ -2524,11 +2386,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
     // here an event is recorded that signals that the main device has finished calculating the input data
     if (split && used_devices > 1) {
         ggml_sycl_set_device(ctx.device);
-        /*
-        DPCT1024:91: The original code returned the error code that was further
-        consumed by the program logic. This original code was replaced with 0.
-        You may need to rewrite the program logic consuming the error code.
-        */
         SYCL_CHECK(CHECK_TRY_ERROR(
             *src0_extra->events[ctx.device][0] =
                 ctx.stream()->ext_oneapi_submit_barrier()));
@@ -2552,11 +2409,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
 
             // wait for main GPU data if necessary
             if (split && (i != ctx.device || is != 0)) {
-                /*
-                DPCT1009:163: SYCL uses exceptions to report errors and does not
-                use the error codes. The original code was commented out and a
-                warning string was inserted. You need to rewrite this code.
-                */
                 SYCL_CHECK(CHECK_TRY_ERROR(stream->ext_oneapi_submit_barrier(
                     {*src0_extra->events[ctx.device][0]})));
             }
@@ -2582,39 +2434,42 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
                 // copy src0, src1 to device if necessary
                 if (src1_is_contiguous) {
                     if (i != ctx.device) {
-                        if (
+                        if constexpr (quantize_enabled) {
                             char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
-
-
-
-
+                            SYCL_CHECK(
+                                CHECK_TRY_ERROR(stream
+                                                    ->memcpy(src1_ddq_i, src1_ddq_i_source,
+                                                             src1_ncols * src1_padded_col_size * q8_1_ts / q8_1_bs)
+                                                    .wait()));
                         } else {
-
                             float * src1_ddf_i_source = (float *) src1_extra->data_device[ctx.device];
-                            src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
+                            src1_ddf_i_source += (i0 * ne11 + src1_col_0) * ne10;
 
-                            SYCL_CHECK(
-                            src1_ddf_i, src1_ddf_i_source,
-
+                            SYCL_CHECK(
+                                CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream, src1_ddf_i, src1_ddf_i_source,
+                                                               src1_ncols * ne10 * sizeof(float))));
                         }
                     }
-                } else if (src1_on_device && !src1_is_contiguous) {
-                    SYCL_CHECK(ggml_sycl_cpy_tensor_2d(
-                        src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
                 } else {
-
-
+                    if (src1_on_device) {
+                        SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, src1_col_0,
+                                                           src1_col_0 + src1_ncols, stream));
+                    } else {
+                        GGML_ABORT("src1 is non-contiguous and not on device");
+                    }
 
-
-
-
-
-
-
-
-
-
+                    if constexpr (quantize_enabled) {
+                        scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
+                                                             /*num_src=*/2, " : converting src1 to Q8_1");
+                        try {
+                            quantize_row_q8_1_sycl<quantize_q8_1>(src1_ddf_i, src1_ddq_i, ne10, src1_ncols,
+                                                                  src1_padded_col_size, stream);
+                        } catch (const sycl::exception & exc) {
+                            std::cerr << "Quantize_row_q8_1_sycl error" << exc.what()
+                                      << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
+                            std::exit(1);
+                        }
+                    }
                 }
 
                 if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
@@ -2626,12 +2481,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
                 // do the computation
                 SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
                     dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));
-                /*
-                DPCT1010:93: SYCL uses exceptions to report errors and does not
-                use the error codes. The call was replaced with 0. You need to
-                rewrite this code.
-                */
-                SYCL_CHECK(0);
 
                 // copy dst to host or other device if necessary
                 if (!dst_on_device) {
@@ -2662,12 +2511,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
 
                 // add event for the main device to wait on until other device is done
                 if (split && (i != ctx.device || is != 0)) {
-                    /*
-                    DPCT1024:94: The original code returned the error code that
-                    was further consumed by the program logic. This original
-                    code was replaced with 0. You may need to rewrite the
-                    program logic consuming the error code.
-                    */
                     SYCL_CHECK(CHECK_TRY_ERROR(
                         *src0_extra->events[i][is] =
                             stream->ext_oneapi_submit_barrier()));
@@ -2766,6 +2609,8 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
     GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->ne[1] == 1);
+    GGML_ASSERT(src1->ne[3] == 1);
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
@@ -2845,6 +2690,9 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     const size_t type_size_src0 = ggml_type_size(src0->type);
     const size_t type_size_src1 = ggml_type_size(src1->type);
 
+    bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
+    bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
+
     // SRC1 strides
     int64_t s11 = nb11 / type_size_src1;
     int64_t s12 = nb12 / type_size_src1;
@@ -2857,9 +2705,9 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
                           " : converting src1 to fp16");
 
     // iterate tensor dims and find the slowest moving dim and stride
-
-
-
+    int last_dim=0;
+    int last_str=0;
+    size_t largest_str=0;
     for(int i = 0; i< 4; i++){
         // last stride is always the largest
         if(src1->nb[i] == largest_str){
@@ -2894,6 +2742,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
         s11 = ne10;
         s12 = ne11 * s11;
         s13 = ne12 * s12;
+
+        is_src1_cont_2 = true;
     }
 
     ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
@@ -2933,7 +2783,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0,
                                                  const sycl::half *src1, float *dst,
                                                  int64_t a0, int64_t a1, int64_t batcha,
-                                                 int64_t b0
+                                                 int64_t /*b0*/, int64_t b1, int64_t batchb,
                                                  int64_t sa0, int64_t sa1, int64_t sa2,
                                                  int64_t sb0, int64_t sb1, int64_t sb2,
                                                  int64_t sd2) {
@@ -2982,14 +2832,26 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
         }
     };
 
-    bool
-    bool
-
+    const bool cont_batches_dim2_a = nb02 * ne02 == nb03;
+    const bool cont_batches_dim2_b = nb12 * ne12 == nb13;
+    const bool cont_batches_dim3_a = ne02 == 1 && nb02 * ne01 == nb03;
+    const bool cont_batches_dim3_b = ne12 == 1 && nb12 * ne11 == nb13;
+    if (cont_batches_dim2_a && cont_batches_dim2_b) {
+        // A batch is considered contiguous if the dimension 2 is not strided
         int64_t batches0 = ne02 * ne03;
         int64_t batches1 = ne12 * ne13;
         launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
                                 ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1,
                                 str_b2, nb2 / sizeof(float));
+    } else if (cont_batches_dim3_a && cont_batches_dim3_b) {
+        // This case is similar to the one above with the difference that only the batch in dimension 3 is used and the dimension 2 is of size 1.
+        int64_t batches0 = ne02 * ne03;
+        int64_t batches1 = ne12 * ne13;
+        int64_t str_a3 = nb03 / type_size_src0;
+        int64_t str_b3 = nb13 / type_size_src1;
+        launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
+                                ne10, ne11, batches1, str_a0, str_a1, str_a3, str_b0, str_b1,
+                                str_b3, nb2 / sizeof(float));
     } else {
         for (int64_t b_a = 0; b_a < ne03; b_a++) {
             const sycl::half *src0_f16_shifted
@@ -3009,12 +2871,16 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
             else
 #endif
             {
-                if (r2 == 1 && r3 == 1 &&
+                if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
+                    // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
+                    const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
+                    const int64_t smb = ne12 == 1 ? s13 : s12;
+
                     // there is no broadcast and src0, src1 are contiguous across dims 2, 3
                     SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
                         oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
-                        src0_f16, dpct::library_data_t::real_half, nb01 / nb00,
-                        src1_f16, dpct::library_data_t::real_half, s11,
+                        src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma,
+                        src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf,
                         mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
             } else {
                 const int ne23 = ne12 * ne13;
@@ -3344,26 +3210,27 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
             // The kernel from the if path is faster for that specific case, but does not support all mul mats.
             ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
         }
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1 && src1->ne[3] == 1) {
         // KQV single-batch
         ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
     } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) {
         // KQ + KQV multi-batch
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
-        constexpr bool convert_src1_to_q8_1 = false;
         opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec
+        ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec);
     } else if (use_mul_mat_vec_q) {
-        constexpr bool convert_src1_to_q8_1 = true;
         opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
-
+        ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
+        if (extra && extra->optimized_feature.reorder) {
+            ggml_sycl_op_mul_mat<quantize_and_reorder_q8_1_soa>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
+        } else {
+            ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
+        }
     } else if (use_mul_mat_q) {
-
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
+        ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q);
     } else {
-
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
+        ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl);
     }
 }
 
@@ -4338,15 +4205,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             {
-                struct ggml_tensor * a;
-                struct ggml_tensor * b;
-
-                    a = op->src[0];
-                    b = op->src[1];
-                } else {
-                    a = op->src[2];
-                    b = op->src[1];
-                }
+                struct ggml_tensor * a = op->src[0];
+                struct ggml_tensor * b = op->src[1];
+
                 if (a->ne[3] != b->ne[3]) {
                     return false;
                 }
@@ -4361,7 +4222,18 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 }
             }
             ggml_type src0_type = op->src[0]->type;
-            if (src0_type == GGML_TYPE_BF16) {
+            if (src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_MXFP4) {
+                // TODO: support MXFP4
+                // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
+                return false;
+            }
+            // TODO: The configuration below needs more work to be supported with oneDNN
+            if (ggml_is_permuted(a) && !ggml_is_contiguous(a) && a->ne[2] > 1 && a->ne[3] > 1) {
+                return false;
+            }
+            // TODO: This specific configuration can fail with oneDNN and needs more debugging
+            if (!ggml_is_permuted(a) && ggml_is_permuted(b) && b->ne[2] > 1 && b->ne[3] > 1 &&
+                a->ne[0] > 128 && a->ne[2] == 1 && src0_type == GGML_TYPE_F16) {
                 return false;
             }
             return true;
@@ -4385,11 +4257,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             }
         case GGML_OP_SET_ROWS:
             {
-
-
-
-
-            }
+                return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
+                         op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q5_0 ||
+                         op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_IQ4_NL) &&
+                        (op->src[1]->type == GGML_TYPE_I64));
+            }
+            break;
         case GGML_OP_CPY:
             {
                 ggml_type src0_type = op->src[0]->type;
@@ -4491,11 +4364,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
            return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
 #endif
         case GGML_OP_NORM:
-        case GGML_OP_RMS_NORM:
            return true;
         case GGML_OP_L2_NORM:
         case GGML_OP_GROUP_NORM:
            return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_RMS_NORM:
+           return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
         case GGML_OP_SCALE:
            return true;
         case GGML_OP_CONT:
@@ -4505,6 +4379,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
            if (op->src[0]->ne[3] != 1) {
                return false;
            }
+           // TODO: support attention sinks [TAG_ATTN_SINKS]
+           if (op->src[2]) {
+               return false;
+           }
            // TODO: support broadcast
            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
@@ -4514,10 +4392,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
            return true;
        case GGML_OP_UPSCALE:
            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
-       case GGML_OP_POOL_2D:
        case GGML_OP_SUM:
        case GGML_OP_SUM_ROWS:
        case GGML_OP_ARGSORT:
+           return ggml_is_contiguous(op->src[0]);
+       case GGML_OP_POOL_2D:
        case GGML_OP_ACC:
        case GGML_OP_PAD:
        case GGML_OP_LEAKY_RELU:
@@ -4730,10 +4609,10 @@ ggml_backend_t ggml_backend_sycl_init(int device) {
     };
 
     ggml_backend_t sycl_backend = new ggml_backend {
-        /* .guid
-        /* .
-        /* .device
-        /* .context
+        /* .guid      = */ ggml_backend_sycl_guid(),
+        /* .iface     = */ ggml_backend_sycl_interface,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device),
+        /* .context   = */ ctx
    };
 
    return sycl_backend;