@novastera-oss/llamarn 0.2.5 → 0.2.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/RNLlamaCpp.podspec +3 -2
- package/android/CMakeLists.txt +6 -3
- package/android/src/main/cpp/include/llama.h +140 -38
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +48 -67
- package/cpp/LlamaCppModel.h +8 -3
- package/cpp/PureCppImpl.cpp +1 -1
- package/cpp/PureCppImpl.h +2 -2
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +15 -4
- package/cpp/llama.cpp/Makefile +2 -2
- package/cpp/llama.cpp/README.md +33 -13
- package/cpp/llama.cpp/common/CMakeLists.txt +15 -28
- package/cpp/llama.cpp/common/arg.cpp +38 -12
- package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
- package/cpp/llama.cpp/common/chat-parser.cpp +9 -3
- package/cpp/llama.cpp/common/chat-parser.h +4 -1
- package/cpp/llama.cpp/common/chat.cpp +16 -13
- package/cpp/llama.cpp/common/chat.h +1 -1
- package/cpp/llama.cpp/common/common.cpp +52 -40
- package/cpp/llama.cpp/common/common.h +5 -2
- package/cpp/llama.cpp/common/json-partial.cpp +5 -4
- package/cpp/llama.cpp/common/json-partial.h +2 -1
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
- package/cpp/llama.cpp/common/json-schema-to-grammar.h +4 -4
- package/cpp/llama.cpp/common/speculative.cpp +6 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +128 -84
- package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
- package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
- package/cpp/llama.cpp/ggml/include/ggml.h +1 -3
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +49 -13
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +10 -5
- package/cpp/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +33 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +6 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +25 -16
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -46
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +118 -11
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -248
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +9 -8
- package/cpp/llama.cpp/ggml/src/ggml.cpp +26 -0
- package/cpp/llama.cpp/ggml/src/gguf.cpp +19 -2
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
- package/cpp/llama.cpp/include/llama.h +140 -38
- package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
- package/cpp/llama.cpp/src/CMakeLists.txt +4 -1
- package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
- package/cpp/llama.cpp/src/llama-arch.h +7 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +289 -31
- package/cpp/llama.cpp/src/llama-batch.h +47 -17
- package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +488 -313
- package/cpp/llama.cpp/src/llama-context.h +38 -17
- package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
- package/cpp/llama.cpp/src/llama-cparams.h +1 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +275 -152
- package/cpp/llama.cpp/src/llama-graph.h +109 -52
- package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
- package/cpp/llama.cpp/src/llama-hparams.h +8 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +281 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +133 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +1835 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +308 -0
- package/cpp/llama.cpp/src/llama-kv-cells.h +53 -17
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +1116 -0
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +188 -0
- package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
- package/cpp/llama.cpp/src/llama-memory.h +89 -4
- package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
- package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
- package/cpp/llama.cpp/src/llama-model.cpp +735 -143
- package/cpp/llama.cpp/src/llama-model.h +4 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
- package/cpp/llama.cpp/src/llama-vocab.cpp +39 -25
- package/cpp/llama.cpp/src/llama.cpp +11 -7
- package/cpp/llama.cpp/src/unicode.cpp +5 -0
- package/cpp/llama.cpp/vendor/cpp-httplib/httplib.h +10518 -0
- package/cpp/llama.cpp/vendor/miniaudio/miniaudio.h +93468 -0
- package/cpp/llama.cpp/{common → vendor}/minja/chat-template.hpp +1 -1
- package/cpp/llama.cpp/{common → vendor}/minja/minja.hpp +1 -1
- package/cpp/llama.cpp/{common → vendor/nlohmann}/json.hpp +3027 -2267
- package/cpp/llama.cpp/vendor/nlohmann/json_fwd.hpp +187 -0
- package/cpp/llama.cpp/vendor/stb/stb_image.h +7988 -0
- package/cpp/rn-completion.cpp +65 -10
- package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
- package/cpp/{rn-utils.hpp → rn-utils.h} +8 -1
- package/ios/include/chat.h +1 -1
- package/ios/include/common/minja/chat-template.hpp +1 -1
- package/ios/include/common/minja/minja.hpp +1 -1
- package/ios/include/common.h +5 -2
- package/ios/include/json-schema-to-grammar.h +4 -4
- package/ios/include/llama.h +140 -38
- package/ios/include/{common → nlohmann}/json.hpp +3027 -2267
- package/ios/libs/llama.xcframework/Info.plist +20 -20
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4617
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3557
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3559
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4616
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4637
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3556
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4653
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4674
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3587
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -2
- package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -2747
- package/cpp/llama.cpp/src/llama-kv-cache.h +0 -502
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
|
@@ -49,10 +49,7 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const
|
|
|
49
49
|
|
|
50
50
|
if (i0 >= n_dims) {
|
|
51
51
|
const int i = row * ne0 + i0;
|
|
52
|
-
|
|
53
|
-
dst[i + 0] = x[i + 0];
|
|
54
|
-
dst[i + 1] = x[i + 1];
|
|
55
|
-
|
|
52
|
+
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
|
|
56
53
|
return;
|
|
57
54
|
}
|
|
58
55
|
|
|
@@ -93,10 +90,7 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
|
|
|
93
90
|
|
|
94
91
|
if (i0 >= n_dims) {
|
|
95
92
|
const int i = row * ne0 + i0;
|
|
96
|
-
|
|
97
|
-
dst[i + 0] = x[i + 0];
|
|
98
|
-
dst[i + 1] = x[i + 1];
|
|
99
|
-
|
|
93
|
+
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
|
|
100
94
|
return;
|
|
101
95
|
}
|
|
102
96
|
|
|
@@ -122,6 +116,63 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
|
|
|
122
116
|
dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
|
|
123
117
|
}
|
|
124
118
|
|
|
119
|
+
template <typename T, bool has_ff>
|
|
120
|
+
static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
|
121
|
+
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
|
|
122
|
+
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
|
|
123
|
+
const float theta_scale, const float * freq_factors, const mrope_sections sections,
|
|
124
|
+
const sycl::nd_item<3> & item_ct1) {
|
|
125
|
+
// get index pos
|
|
126
|
+
const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
|
|
127
|
+
if (i0 >= ne0) {
|
|
128
|
+
return;
|
|
129
|
+
}
|
|
130
|
+
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
|
|
131
|
+
|
|
132
|
+
if (i0 >= n_dims) {
|
|
133
|
+
const int i = row_dst*ne0 + i0;
|
|
134
|
+
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
|
|
135
|
+
return;
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
const int row_x = row_dst % ne1;
|
|
139
|
+
const int channel_x = row_dst / ne1;
|
|
140
|
+
const int idst = (row_dst * ne0) + (i0 / 2);
|
|
141
|
+
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
|
|
142
|
+
|
|
143
|
+
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
|
|
144
|
+
const int sec_w = sections.v[1] + sections.v[0];
|
|
145
|
+
const int sector = (i0 / 2) % sect_dims;
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
float theta_base = 0.0;
|
|
149
|
+
if (sector < sections.v[0]) {
|
|
150
|
+
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
|
|
151
|
+
}
|
|
152
|
+
else if (sector >= sections.v[0] && sector < sec_w) {
|
|
153
|
+
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
|
|
154
|
+
}
|
|
155
|
+
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
|
156
|
+
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
|
|
157
|
+
}
|
|
158
|
+
else if (sector >= sec_w + sections.v[2]) {
|
|
159
|
+
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
|
163
|
+
float cos_theta;
|
|
164
|
+
float sin_theta;
|
|
165
|
+
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
|
166
|
+
const float x0 = x[ix + 0];
|
|
167
|
+
const float x1 = x[ix + n_dims/2];
|
|
168
|
+
|
|
169
|
+
// store results in dst
|
|
170
|
+
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
|
|
171
|
+
dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta;
|
|
172
|
+
}
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
|
|
125
176
|
template <typename T, bool has_ff>
|
|
126
177
|
static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
|
127
178
|
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
|
|
@@ -171,7 +222,7 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c
|
|
|
171
222
|
const float * freq_factors, queue_ptr stream) {
|
|
172
223
|
GGML_ASSERT(ne0 % 2 == 0);
|
|
173
224
|
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
|
174
|
-
const int num_blocks_x = (ne0
|
|
225
|
+
const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
|
175
226
|
const sycl::range<3> block_nums(1, num_blocks_x, nr);
|
|
176
227
|
|
|
177
228
|
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
|
@@ -208,7 +259,7 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
|
|
|
208
259
|
const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
|
|
209
260
|
GGML_ASSERT(ne0 % 2 == 0);
|
|
210
261
|
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
|
211
|
-
const int num_blocks_x = (ne0
|
|
262
|
+
const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
|
212
263
|
const sycl::range<3> block_nums(1, num_blocks_x, nr);
|
|
213
264
|
|
|
214
265
|
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
|
@@ -228,6 +279,40 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
|
|
|
228
279
|
}
|
|
229
280
|
}
|
|
230
281
|
|
|
282
|
+
template <typename T>
|
|
283
|
+
static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
|
284
|
+
const size_t s2, const int n_dims, const int nr, const int32_t * pos,
|
|
285
|
+
const float freq_scale, const float freq_base, const float ext_factor,
|
|
286
|
+
const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
|
|
287
|
+
const mrope_sections sections, queue_ptr stream) {
|
|
288
|
+
GGML_ASSERT(ne0 % 2 == 0);
|
|
289
|
+
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
|
290
|
+
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
|
291
|
+
const sycl::range<3> grid_dims(1, n_blocks_y, nr);
|
|
292
|
+
const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
|
|
293
|
+
|
|
294
|
+
const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
|
|
295
|
+
// Add FP16 capability check if T could be sycl::half
|
|
296
|
+
if constexpr (std::is_same_v<T, sycl::half>) {
|
|
297
|
+
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
|
298
|
+
}
|
|
299
|
+
// launch kernel
|
|
300
|
+
if (freq_factors == nullptr) {
|
|
301
|
+
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
|
302
|
+
rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
|
303
|
+
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
|
304
|
+
});
|
|
305
|
+
} else {
|
|
306
|
+
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
|
307
|
+
rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
|
308
|
+
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
|
309
|
+
});
|
|
310
|
+
}
|
|
311
|
+
}
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
|
|
231
316
|
// rope vision
|
|
232
317
|
template <typename T>
|
|
233
318
|
static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
|
@@ -237,7 +322,7 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
|
|
|
237
322
|
const mrope_sections sections, queue_ptr stream) {
|
|
238
323
|
GGML_ASSERT(ne0 % 2 == 0);
|
|
239
324
|
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
|
240
|
-
const int n_blocks_y = (ne0
|
|
325
|
+
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
|
241
326
|
const sycl::range<3> grid_dims(1, n_blocks_y, nr);
|
|
242
327
|
const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
|
|
243
328
|
|
|
@@ -298,8 +383,17 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
|
|
298
383
|
memcpy(§ions.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
|
299
384
|
|
|
300
385
|
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
|
386
|
+
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
|
301
387
|
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
|
302
388
|
|
|
389
|
+
if (is_mrope) {
|
|
390
|
+
GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
|
|
391
|
+
}
|
|
392
|
+
|
|
393
|
+
if (is_vision) {
|
|
394
|
+
GGML_ASSERT(n_dims == ne00/2);
|
|
395
|
+
}
|
|
396
|
+
|
|
303
397
|
const int32_t * pos = (const int32_t *) dst->src[1]->data;
|
|
304
398
|
|
|
305
399
|
const float * freq_factors = nullptr;
|
|
@@ -326,6 +420,19 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
|
|
326
420
|
} else {
|
|
327
421
|
GGML_ABORT("fatal error");
|
|
328
422
|
}
|
|
423
|
+
} else if (is_mrope && !is_vision) {
|
|
424
|
+
GGML_SYCL_DEBUG("%s: mrope path\n", __func__);
|
|
425
|
+
if (dst->src[0]->type == GGML_TYPE_F16) {
|
|
426
|
+
rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
|
|
427
|
+
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
|
428
|
+
freq_factors, sections, main_stream);
|
|
429
|
+
} else if (dst->src[0]->type == GGML_TYPE_F32) {
|
|
430
|
+
rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
|
|
431
|
+
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
|
|
432
|
+
main_stream);
|
|
433
|
+
} else {
|
|
434
|
+
GGML_ABORT("Fatal error: Tensor type unsupported!");
|
|
435
|
+
}
|
|
329
436
|
} else if (is_vision) {
|
|
330
437
|
GGML_SYCL_DEBUG("%s: vision path\n", __func__);
|
|
331
438
|
if (dst->src[0]->type == GGML_TYPE_F16) {
|
|
@@ -284,22 +284,23 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
|
|
|
284
284
|
return d4 * (sumi * ds8f.x() - (8 * q4_0_traits::vdr_mmvq / q4_0_traits::qi) * ds8f.y());
|
|
285
285
|
}
|
|
286
286
|
|
|
287
|
-
__dpct_inline__ float operator()(const void * __restrict__ vbq, const int
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
const
|
|
287
|
+
__dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
|
|
288
|
+
const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,
|
|
289
|
+
const sycl::half2 * q8_1_ds, const int & iqs) {
|
|
290
|
+
const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset.first;
|
|
291
|
+
const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset.first));
|
|
291
292
|
int v[q4_0_traits::vdr_mmvq];
|
|
292
293
|
int u[2 * q4_0_traits::vdr_mmvq];
|
|
293
294
|
|
|
294
|
-
#pragma unroll
|
|
295
295
|
|
|
296
|
+
#pragma unroll
|
|
296
297
|
for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) {
|
|
297
298
|
v[i] = get_int_from_uint8(bq4_0, iqs + i);
|
|
298
|
-
u[2 * i + 0] = get_int_from_int8_aligned(
|
|
299
|
-
u[2 * i + 1] = get_int_from_int8_aligned(
|
|
299
|
+
u[2 * i + 0] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i);
|
|
300
|
+
u[2 * i + 1] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i + q4_0_traits::qi);
|
|
300
301
|
}
|
|
301
302
|
|
|
302
|
-
return vec_dot_q4_0_q8_1_impl(v, u, d,
|
|
303
|
+
return vec_dot_q4_0_q8_1_impl(v, u, d, *q8_1_ds);
|
|
303
304
|
};
|
|
304
305
|
};
|
|
305
306
|
|
|
@@ -346,24 +347,115 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
|
|
|
346
347
|
using q4_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_K>;
|
|
347
348
|
using q4_k_traits = typename q4_k_block::traits;
|
|
348
349
|
|
|
349
|
-
float operator()(const void * __restrict__ vbq, const int
|
|
350
|
-
|
|
351
|
-
|
|
350
|
+
__dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
|
|
351
|
+
const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,
|
|
352
|
+
const sycl::half2 * q8_1_ds, const int & iqs) {
|
|
353
|
+
const int ib = ibx_offset.first / (QK_K / 2);
|
|
352
354
|
|
|
353
355
|
const uint8_t * base = static_cast<const uint8_t *>(vbq);
|
|
354
|
-
const uint8_t * qs = base + ibx_offset;
|
|
355
|
-
const
|
|
356
|
-
const
|
|
357
|
-
const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset);
|
|
356
|
+
const uint8_t * qs = base + ibx_offset.first;
|
|
357
|
+
const uint8_t * scs = base + d_offset.first + ib * K_SCALE_SIZE;
|
|
358
|
+
const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset.second);
|
|
358
359
|
|
|
359
360
|
const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
|
|
360
361
|
const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
|
|
361
362
|
const uint16_t * scales = (const uint16_t *) scs;
|
|
362
363
|
|
|
363
|
-
|
|
364
|
+
int v[2];
|
|
365
|
+
int u[2 * QR4_K];
|
|
366
|
+
float d8[QR4_K];
|
|
367
|
+
|
|
368
|
+
v[0] = q4[0];
|
|
369
|
+
v[1] = q4[4];
|
|
370
|
+
|
|
371
|
+
uint16_t aux[2];
|
|
372
|
+
const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2;
|
|
373
|
+
if (j < 2) {
|
|
374
|
+
aux[0] = scales[j + 0] & 0x3f3f;
|
|
375
|
+
aux[1] = scales[j + 2] & 0x3f3f;
|
|
376
|
+
} else {
|
|
377
|
+
aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);
|
|
378
|
+
aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);
|
|
379
|
+
}
|
|
380
|
+
|
|
381
|
+
const uint8_t * sc = (const uint8_t *) aux;
|
|
382
|
+
const uint8_t * m = sc + 2;
|
|
383
|
+
|
|
384
|
+
for (int i = 0; i < QR4_K; ++i) {
|
|
385
|
+
const int8_t* quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1;
|
|
386
|
+
sycl::half2 ds_values = *(q8_1_ds + bq8_offset + i);
|
|
387
|
+
|
|
388
|
+
d8[i] = ds_values[0];
|
|
389
|
+
|
|
390
|
+
const int * q8 = (const int *) quant_base_ptr + ((iqs / 2) % 4);
|
|
391
|
+
u[2 * i + 0] = q8[0];
|
|
392
|
+
u[2 * i + 1] = q8[4];
|
|
393
|
+
}
|
|
394
|
+
|
|
395
|
+
return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, *dms, d8);
|
|
364
396
|
}
|
|
365
397
|
};
|
|
366
398
|
|
|
399
|
+
template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K> {
|
|
400
|
+
static constexpr ggml_type gtype = GGML_TYPE_Q6_K;
|
|
401
|
+
|
|
402
|
+
using q6_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q6_K>;
|
|
403
|
+
using q6_k_traits = typename q6_k_block::traits;
|
|
404
|
+
|
|
405
|
+
__dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq(const int vl, const int vh, const int * __restrict__ u,
|
|
406
|
+
const int8_t * __restrict__ scales, const float d,
|
|
407
|
+
const float * __restrict__ d8) {
|
|
408
|
+
float sumf = 0.0f;
|
|
409
|
+
|
|
410
|
+
#pragma unroll
|
|
411
|
+
for (int i = 0; i < QR6_K; ++i) {
|
|
412
|
+
const int sc = scales[4 * i];
|
|
413
|
+
|
|
414
|
+
const int vil = (vl >> (4 * i)) & 0x0F0F0F0F;
|
|
415
|
+
|
|
416
|
+
const int vih = ((vh >> (4 * i)) << 4) & 0x30303030;
|
|
417
|
+
|
|
418
|
+
const int vi = dpct::vectorized_binary<sycl::char4>((vil | vih), 0x20202020,
|
|
419
|
+
dpct::sub_sat()); // vi = (vil | vih) - 32
|
|
420
|
+
|
|
421
|
+
sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product
|
|
422
|
+
}
|
|
423
|
+
|
|
424
|
+
return d * sumf;
|
|
425
|
+
}
|
|
426
|
+
|
|
427
|
+
__dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
|
|
428
|
+
const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr, const sycl::half2 * q8_1_ds,
|
|
429
|
+
const int iqs) {
|
|
430
|
+
const int ib = ibx_offset.first / (QK_K / 2);
|
|
431
|
+
|
|
432
|
+
const uint8_t * base = static_cast<const uint8_t *>(vbq);
|
|
433
|
+
const uint8_t * ql = base + ibx_offset.first;
|
|
434
|
+
const uint8_t * qh = base + ibx_offset.second;
|
|
435
|
+
const int8_t * scales = reinterpret_cast<const int8_t *>(base + d_offset.first);
|
|
436
|
+
const ggml_half * d = (const ggml_half *) (base + d_offset.second) + ib;
|
|
437
|
+
|
|
438
|
+
const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 4);
|
|
439
|
+
const int scale_offset = (QI6_K / 4) * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 8);
|
|
440
|
+
const int vh_shift = 2 * ((iqs % (QI6_K / 2)) / (QI6_K / 4));
|
|
441
|
+
|
|
442
|
+
const int vl = get_int_from_uint8(ql, iqs);
|
|
443
|
+
const int vh = get_int_from_uint8(qh, (QI6_K / 4) * (iqs / (QI6_K / 2)) + iqs % (QI6_K / 4)) >> vh_shift;
|
|
444
|
+
|
|
445
|
+
const int8_t * scs = scales + scale_offset;
|
|
446
|
+
|
|
447
|
+
int u[QR6_K];
|
|
448
|
+
float d8[QR6_K];
|
|
449
|
+
|
|
450
|
+
#pragma unroll
|
|
451
|
+
for (int i = 0; i < QR6_K; ++i) {
|
|
452
|
+
u[i] = get_int_from_int8_aligned(q8_1_quant_ptr + (bq8_offset + 2 * i) * QK8_1, iqs % QI8_1);
|
|
453
|
+
const sycl::half2 ds_values = *(q8_1_ds + bq8_offset + 2 * i);
|
|
454
|
+
d8[i] = ds_values[0];
|
|
455
|
+
}
|
|
456
|
+
return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scs, *d, d8);
|
|
457
|
+
}
|
|
458
|
+
};
|
|
367
459
|
#define VDR_Q4_0_Q8_1_MMVQ 2
|
|
368
460
|
#define VDR_Q4_0_Q8_1_MMQ 4
|
|
369
461
|
|
|
@@ -49,15 +49,7 @@ if (Vulkan_FOUND)
|
|
|
49
49
|
../../include/ggml-vulkan.h
|
|
50
50
|
)
|
|
51
51
|
|
|
52
|
-
set(VULKAN_SHADER_GEN_CMAKE_ARGS
|
|
53
|
-
-DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}
|
|
54
|
-
-DCMAKE_RUNTIME_OUTPUT_DIRECTORY=${CMAKE_RUNTIME_OUTPUT_DIRECTORY}
|
|
55
|
-
)
|
|
56
|
-
|
|
57
|
-
set(VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS "")
|
|
58
|
-
if (CMAKE_BUILD_TYPE AND CMAKE_BUILD_TYPE MATCHES "Debug|Release|MinSizeRel|RelWithDebInfo")
|
|
59
|
-
list(APPEND VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS --config=${CMAKE_BUILD_TYPE})
|
|
60
|
-
endif()
|
|
52
|
+
set(VULKAN_SHADER_GEN_CMAKE_ARGS "")
|
|
61
53
|
|
|
62
54
|
# Test all shader extensions
|
|
63
55
|
test_shader_extension_support(
|
|
@@ -136,42 +128,45 @@ if (Vulkan_FOUND)
|
|
|
136
128
|
set(HOST_CMAKE_TOOLCHAIN_FILE "")
|
|
137
129
|
endif()
|
|
138
130
|
|
|
139
|
-
# Always use ExternalProject_Add approach
|
|
140
131
|
include(ExternalProject)
|
|
141
132
|
|
|
142
|
-
# Add toolchain file if cross-compiling
|
|
143
133
|
if (CMAKE_CROSSCOMPILING)
|
|
144
134
|
list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE})
|
|
145
135
|
message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}")
|
|
146
136
|
endif()
|
|
147
137
|
|
|
148
|
-
# Native build through ExternalProject_Add
|
|
149
138
|
ExternalProject_Add(
|
|
150
139
|
vulkan-shaders-gen
|
|
151
140
|
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
|
|
152
|
-
CMAKE_ARGS
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
141
|
+
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}/$<CONFIG>
|
|
142
|
+
-DCMAKE_INSTALL_BINDIR=.
|
|
143
|
+
-DCMAKE_BUILD_TYPE=$<CONFIG>
|
|
144
|
+
${VULKAN_SHADER_GEN_CMAKE_ARGS}
|
|
145
|
+
|
|
146
|
+
BUILD_COMMAND ${CMAKE_COMMAND} --build . --config $<CONFIG>
|
|
147
|
+
|
|
148
|
+
# NOTE: When DESTDIR is set using Makefile generators and
|
|
149
|
+
# "make install" triggers the build step, vulkan-shaders-gen
|
|
150
|
+
# would be installed into the DESTDIR prefix, so it is unset
|
|
151
|
+
# to ensure that does not happen.
|
|
152
|
+
|
|
153
|
+
INSTALL_COMMAND ${CMAKE_COMMAND} -E env --unset=DESTDIR
|
|
154
|
+
${CMAKE_COMMAND} --install . --config $<CONFIG>
|
|
156
155
|
)
|
|
157
|
-
ExternalProject_Add_StepTargets(vulkan-shaders-gen build install)
|
|
158
156
|
|
|
159
157
|
set (_ggml_vk_host_suffix $<IF:$<STREQUAL:${CMAKE_HOST_SYSTEM_NAME},Windows>,.exe,>)
|
|
160
|
-
set (
|
|
161
|
-
set (
|
|
162
|
-
set (
|
|
163
|
-
set (
|
|
164
|
-
set (
|
|
158
|
+
set (_ggml_vk_genshaders_dir "${CMAKE_BINARY_DIR}/$<CONFIG>")
|
|
159
|
+
set (_ggml_vk_genshaders_cmd "${_ggml_vk_genshaders_dir}/vulkan-shaders-gen${_ggml_vk_host_suffix}")
|
|
160
|
+
set (_ggml_vk_header "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp")
|
|
161
|
+
set (_ggml_vk_source "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp")
|
|
162
|
+
set (_ggml_vk_input_dir "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders")
|
|
163
|
+
set (_ggml_vk_output_dir "${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv")
|
|
165
164
|
|
|
166
|
-
file(GLOB
|
|
167
|
-
set (_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen)
|
|
168
|
-
|
|
169
|
-
# Add build and install dependencies for all builds
|
|
170
|
-
set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install)
|
|
165
|
+
file(GLOB _ggml_vk_shader_files CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.comp")
|
|
171
166
|
|
|
172
167
|
add_custom_command(
|
|
173
168
|
OUTPUT ${_ggml_vk_header}
|
|
174
|
-
|
|
169
|
+
${_ggml_vk_source}
|
|
175
170
|
|
|
176
171
|
COMMAND ${_ggml_vk_genshaders_cmd}
|
|
177
172
|
--glslc ${Vulkan_GLSLC_EXECUTABLE}
|
|
@@ -181,7 +176,9 @@ if (Vulkan_FOUND)
|
|
|
181
176
|
--target-cpp ${_ggml_vk_source}
|
|
182
177
|
--no-clean
|
|
183
178
|
|
|
184
|
-
DEPENDS ${
|
|
179
|
+
DEPENDS ${_ggml_vk_shader_files}
|
|
180
|
+
vulkan-shaders-gen
|
|
181
|
+
|
|
185
182
|
COMMENT "Generate vulkan shaders"
|
|
186
183
|
)
|
|
187
184
|
|