@novastera-oss/llamarn 0.2.5 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (225) hide show
  1. package/RNLlamaCpp.podspec +3 -2
  2. package/android/CMakeLists.txt +6 -3
  3. package/android/src/main/cpp/include/llama.h +140 -38
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  12. package/cpp/LlamaCppModel.cpp +48 -67
  13. package/cpp/LlamaCppModel.h +8 -3
  14. package/cpp/PureCppImpl.cpp +1 -1
  15. package/cpp/PureCppImpl.h +2 -2
  16. package/cpp/build-info.cpp +2 -2
  17. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  18. package/cpp/llama.cpp/Makefile +2 -2
  19. package/cpp/llama.cpp/README.md +33 -13
  20. package/cpp/llama.cpp/common/CMakeLists.txt +15 -28
  21. package/cpp/llama.cpp/common/arg.cpp +38 -12
  22. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  23. package/cpp/llama.cpp/common/chat-parser.cpp +9 -3
  24. package/cpp/llama.cpp/common/chat-parser.h +4 -1
  25. package/cpp/llama.cpp/common/chat.cpp +16 -13
  26. package/cpp/llama.cpp/common/chat.h +1 -1
  27. package/cpp/llama.cpp/common/common.cpp +52 -40
  28. package/cpp/llama.cpp/common/common.h +5 -2
  29. package/cpp/llama.cpp/common/json-partial.cpp +5 -4
  30. package/cpp/llama.cpp/common/json-partial.h +2 -1
  31. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
  32. package/cpp/llama.cpp/common/json-schema-to-grammar.h +4 -4
  33. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  34. package/cpp/llama.cpp/convert_hf_to_gguf.py +128 -84
  35. package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
  36. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  37. package/cpp/llama.cpp/ggml/include/ggml.h +1 -3
  38. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +49 -13
  39. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +10 -5
  41. package/cpp/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  42. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +33 -2
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
  74. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +6 -8
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +25 -16
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  79. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-impl.h +2 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  82. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
  83. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
  86. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  88. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  93. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  94. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  96. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
  97. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
  98. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
  99. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
  102. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -46
  103. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  104. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  105. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +118 -11
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  107. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
  108. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -248
  109. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
  110. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  112. package/cpp/llama.cpp/ggml/src/ggml.c +9 -8
  113. package/cpp/llama.cpp/ggml/src/ggml.cpp +26 -0
  114. package/cpp/llama.cpp/ggml/src/gguf.cpp +19 -2
  115. package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
  116. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
  117. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
  118. package/cpp/llama.cpp/include/llama.h +140 -38
  119. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  120. package/cpp/llama.cpp/src/CMakeLists.txt +4 -1
  121. package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
  122. package/cpp/llama.cpp/src/llama-arch.h +7 -1
  123. package/cpp/llama.cpp/src/llama-batch.cpp +289 -31
  124. package/cpp/llama.cpp/src/llama-batch.h +47 -17
  125. package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
  126. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  127. package/cpp/llama.cpp/src/llama-context.cpp +488 -313
  128. package/cpp/llama.cpp/src/llama-context.h +38 -17
  129. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  130. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  131. package/cpp/llama.cpp/src/llama-graph.cpp +275 -152
  132. package/cpp/llama.cpp/src/llama-graph.h +109 -52
  133. package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
  134. package/cpp/llama.cpp/src/llama-hparams.h +8 -2
  135. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +281 -0
  136. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +133 -0
  137. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +1835 -0
  138. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +308 -0
  139. package/cpp/llama.cpp/src/llama-kv-cells.h +53 -17
  140. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
  141. package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
  142. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +1116 -0
  143. package/cpp/llama.cpp/src/llama-memory-recurrent.h +188 -0
  144. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  145. package/cpp/llama.cpp/src/llama-memory.h +89 -4
  146. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  147. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  148. package/cpp/llama.cpp/src/llama-model.cpp +735 -143
  149. package/cpp/llama.cpp/src/llama-model.h +4 -0
  150. package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
  151. package/cpp/llama.cpp/src/llama-vocab.cpp +39 -25
  152. package/cpp/llama.cpp/src/llama.cpp +11 -7
  153. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  154. package/cpp/llama.cpp/vendor/cpp-httplib/httplib.h +10518 -0
  155. package/cpp/llama.cpp/vendor/miniaudio/miniaudio.h +93468 -0
  156. package/cpp/llama.cpp/{common → vendor}/minja/chat-template.hpp +1 -1
  157. package/cpp/llama.cpp/{common → vendor}/minja/minja.hpp +1 -1
  158. package/cpp/llama.cpp/{common → vendor/nlohmann}/json.hpp +3027 -2267
  159. package/cpp/llama.cpp/vendor/nlohmann/json_fwd.hpp +187 -0
  160. package/cpp/llama.cpp/vendor/stb/stb_image.h +7988 -0
  161. package/cpp/rn-completion.cpp +65 -10
  162. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  163. package/cpp/{rn-utils.hpp → rn-utils.h} +8 -1
  164. package/ios/include/chat.h +1 -1
  165. package/ios/include/common/minja/chat-template.hpp +1 -1
  166. package/ios/include/common/minja/minja.hpp +1 -1
  167. package/ios/include/common.h +5 -2
  168. package/ios/include/json-schema-to-grammar.h +4 -4
  169. package/ios/include/llama.h +140 -38
  170. package/ios/include/{common → nlohmann}/json.hpp +3027 -2267
  171. package/ios/libs/llama.xcframework/Info.plist +20 -20
  172. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  173. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4617
  174. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +1 -3
  175. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +140 -38
  176. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  177. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  178. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
  179. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3557
  180. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  181. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
  182. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  183. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  184. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
  185. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3559
  186. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +1 -3
  187. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +140 -38
  188. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +1 -3
  189. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +140 -38
  190. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  191. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +1 -3
  192. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +140 -38
  193. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  194. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  195. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  196. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4616
  197. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +1 -3
  198. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +140 -38
  199. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  200. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  201. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4637
  202. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3556
  203. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  204. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
  205. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  206. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  207. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4653
  208. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +1 -3
  209. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +140 -38
  210. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  211. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  212. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4674
  213. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3587
  214. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  215. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
  216. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  217. package/package.json +1 -2
  218. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  219. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  220. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  221. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -2747
  222. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -502
  223. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  224. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  225. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
@@ -49,10 +49,7 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const
49
49
 
50
50
  if (i0 >= n_dims) {
51
51
  const int i = row * ne0 + i0;
52
-
53
- dst[i + 0] = x[i + 0];
54
- dst[i + 1] = x[i + 1];
55
-
52
+ *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
56
53
  return;
57
54
  }
58
55
 
@@ -93,10 +90,7 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
93
90
 
94
91
  if (i0 >= n_dims) {
95
92
  const int i = row * ne0 + i0;
96
-
97
- dst[i + 0] = x[i + 0];
98
- dst[i + 1] = x[i + 1];
99
-
93
+ *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
100
94
  return;
101
95
  }
102
96
 
@@ -122,6 +116,63 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
122
116
  dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
123
117
  }
124
118
 
119
+ template <typename T, bool has_ff>
120
+ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
121
+ const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
122
+ const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
123
+ const float theta_scale, const float * freq_factors, const mrope_sections sections,
124
+ const sycl::nd_item<3> & item_ct1) {
125
+ // get index pos
126
+ const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
127
+ if (i0 >= ne0) {
128
+ return;
129
+ }
130
+ const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
131
+
132
+ if (i0 >= n_dims) {
133
+ const int i = row_dst*ne0 + i0;
134
+ *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
135
+ return;
136
+ }
137
+
138
+ const int row_x = row_dst % ne1;
139
+ const int channel_x = row_dst / ne1;
140
+ const int idst = (row_dst * ne0) + (i0 / 2);
141
+ const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
142
+
143
+ const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
144
+ const int sec_w = sections.v[1] + sections.v[0];
145
+ const int sector = (i0 / 2) % sect_dims;
146
+
147
+
148
+ float theta_base = 0.0;
149
+ if (sector < sections.v[0]) {
150
+ theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
151
+ }
152
+ else if (sector >= sections.v[0] && sector < sec_w) {
153
+ theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
154
+ }
155
+ else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
156
+ theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
157
+ }
158
+ else if (sector >= sec_w + sections.v[2]) {
159
+ theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
160
+ }
161
+
162
+ const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
163
+ float cos_theta;
164
+ float sin_theta;
165
+ rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
166
+ const float x0 = x[ix + 0];
167
+ const float x1 = x[ix + n_dims/2];
168
+
169
+ // store results in dst
170
+ dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
171
+ dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta;
172
+ }
173
+
174
+
175
+
125
176
  template <typename T, bool has_ff>
126
177
  static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
127
178
  const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
@@ -171,7 +222,7 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c
171
222
  const float * freq_factors, queue_ptr stream) {
172
223
  GGML_ASSERT(ne0 % 2 == 0);
173
224
  const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
174
- const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
225
+ const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
175
226
  const sycl::range<3> block_nums(1, num_blocks_x, nr);
176
227
 
177
228
  const float theta_scale = powf(freq_base, -2.0f / n_dims);
@@ -208,7 +259,7 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
208
259
  const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
209
260
  GGML_ASSERT(ne0 % 2 == 0);
210
261
  const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
211
- const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
262
+ const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
212
263
  const sycl::range<3> block_nums(1, num_blocks_x, nr);
213
264
 
214
265
  const float theta_scale = powf(freq_base, -2.0f / n_dims);
@@ -228,6 +279,40 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
228
279
  }
229
280
  }
230
281
 
282
+ template <typename T>
283
+ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
284
+ const size_t s2, const int n_dims, const int nr, const int32_t * pos,
285
+ const float freq_scale, const float freq_base, const float ext_factor,
286
+ const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
287
+ const mrope_sections sections, queue_ptr stream) {
288
+ GGML_ASSERT(ne0 % 2 == 0);
289
+ const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
290
+ const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
291
+ const sycl::range<3> grid_dims(1, n_blocks_y, nr);
292
+ const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
293
+
294
+ const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
295
+ // Add FP16 capability check if T could be sycl::half
296
+ if constexpr (std::is_same_v<T, sycl::half>) {
297
+ dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
298
+ }
299
+ // launch kernel
300
+ if (freq_factors == nullptr) {
301
+ stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
302
+ rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
303
+ corr_dims, theta_scale, freq_factors, sections, item_ct1);
304
+ });
305
+ } else {
306
+ stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
307
+ rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
308
+ corr_dims, theta_scale, freq_factors, sections, item_ct1);
309
+ });
310
+ }
311
+ }
312
+
313
+
314
+
315
+
231
316
  // rope vision
232
317
  template <typename T>
233
318
  static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
@@ -237,7 +322,7 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
237
322
  const mrope_sections sections, queue_ptr stream) {
238
323
  GGML_ASSERT(ne0 % 2 == 0);
239
324
  const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
240
- const int n_blocks_y = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
325
+ const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
241
326
  const sycl::range<3> grid_dims(1, n_blocks_y, nr);
242
327
  const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
243
328
 
@@ -298,8 +383,17 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
298
383
  memcpy(&sections.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
299
384
 
300
385
  const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
386
+ const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
301
387
  const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
302
388
 
389
+ if (is_mrope) {
390
+ GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
391
+ }
392
+
393
+ if (is_vision) {
394
+ GGML_ASSERT(n_dims == ne00/2);
395
+ }
396
+
303
397
  const int32_t * pos = (const int32_t *) dst->src[1]->data;
304
398
 
305
399
  const float * freq_factors = nullptr;
@@ -326,6 +420,19 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
326
420
  } else {
327
421
  GGML_ABORT("fatal error");
328
422
  }
423
+ } else if (is_mrope && !is_vision) {
424
+ GGML_SYCL_DEBUG("%s: mrope path\n", __func__);
425
+ if (dst->src[0]->type == GGML_TYPE_F16) {
426
+ rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
427
+ s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
428
+ freq_factors, sections, main_stream);
429
+ } else if (dst->src[0]->type == GGML_TYPE_F32) {
430
+ rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
431
+ nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
432
+ main_stream);
433
+ } else {
434
+ GGML_ABORT("Fatal error: Tensor type unsupported!");
435
+ }
329
436
  } else if (is_vision) {
330
437
  GGML_SYCL_DEBUG("%s: vision path\n", __func__);
331
438
  if (dst->src[0]->type == GGML_TYPE_F16) {
@@ -284,22 +284,23 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
284
284
  return d4 * (sumi * ds8f.x() - (8 * q4_0_traits::vdr_mmvq / q4_0_traits::qi) * ds8f.y());
285
285
  }
286
286
 
287
- __dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
288
- const block_q8_1 * __restrict__ bq8_1, const int & iqs, int /* nblocks */) {
289
- const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset;
290
- const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset));
287
+ __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
288
+ const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,
289
+ const sycl::half2 * q8_1_ds, const int & iqs) {
290
+ const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset.first;
291
+ const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset.first));
291
292
  int v[q4_0_traits::vdr_mmvq];
292
293
  int u[2 * q4_0_traits::vdr_mmvq];
293
294
 
294
- #pragma unroll
295
295
 
296
+ #pragma unroll
296
297
  for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) {
297
298
  v[i] = get_int_from_uint8(bq4_0, iqs + i);
298
- u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
299
- u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + q4_0_traits::qi);
299
+ u[2 * i + 0] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i);
300
+ u[2 * i + 1] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i + q4_0_traits::qi);
300
301
  }
301
302
 
302
- return vec_dot_q4_0_q8_1_impl(v, u, d, bq8_1->ds);
303
+ return vec_dot_q4_0_q8_1_impl(v, u, d, *q8_1_ds);
303
304
  };
304
305
  };
305
306
 
@@ -346,24 +347,115 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
346
347
  using q4_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_K>;
347
348
  using q4_k_traits = typename q4_k_block::traits;
348
349
 
349
- float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
350
- const block_q8_1 * __restrict__ bq8_1, const int & iqs, int nblocks) {
351
- const int ib = ibx_offset / (QK_K / 2);
350
+ __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
351
+ const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,
352
+ const sycl::half2 * q8_1_ds, const int & iqs) {
353
+ const int ib = ibx_offset.first / (QK_K / 2);
352
354
 
353
355
  const uint8_t * base = static_cast<const uint8_t *>(vbq);
354
- const uint8_t * qs = base + ibx_offset;
355
- const int total_qs_bytes = nblocks * (QK_K / 2);
356
- const uint8_t * scs = base + total_qs_bytes + ib * K_SCALE_SIZE;
357
- const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset);
356
+ const uint8_t * qs = base + ibx_offset.first;
357
+ const uint8_t * scs = base + d_offset.first + ib * K_SCALE_SIZE;
358
+ const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset.second);
358
359
 
359
360
  const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
360
361
  const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
361
362
  const uint16_t * scales = (const uint16_t *) scs;
362
363
 
363
- return vec_dot_q4_K_q8_1_common(q4, scales, *dms, bq8_1, iqs);
364
+ int v[2];
365
+ int u[2 * QR4_K];
366
+ float d8[QR4_K];
367
+
368
+ v[0] = q4[0];
369
+ v[1] = q4[4];
370
+
371
+ uint16_t aux[2];
372
+ const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2;
373
+ if (j < 2) {
374
+ aux[0] = scales[j + 0] & 0x3f3f;
375
+ aux[1] = scales[j + 2] & 0x3f3f;
376
+ } else {
377
+ aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);
378
+ aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);
379
+ }
380
+
381
+ const uint8_t * sc = (const uint8_t *) aux;
382
+ const uint8_t * m = sc + 2;
383
+
384
+ for (int i = 0; i < QR4_K; ++i) {
385
+ const int8_t* quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1;
386
+ sycl::half2 ds_values = *(q8_1_ds + bq8_offset + i);
387
+
388
+ d8[i] = ds_values[0];
389
+
390
+ const int * q8 = (const int *) quant_base_ptr + ((iqs / 2) % 4);
391
+ u[2 * i + 0] = q8[0];
392
+ u[2 * i + 1] = q8[4];
393
+ }
394
+
395
+ return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, *dms, d8);
364
396
  }
365
397
  };
366
398
 
399
+ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K> {
400
+ static constexpr ggml_type gtype = GGML_TYPE_Q6_K;
401
+
402
+ using q6_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q6_K>;
403
+ using q6_k_traits = typename q6_k_block::traits;
404
+
405
+ __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq(const int vl, const int vh, const int * __restrict__ u,
406
+ const int8_t * __restrict__ scales, const float d,
407
+ const float * __restrict__ d8) {
408
+ float sumf = 0.0f;
409
+
410
+ #pragma unroll
411
+ for (int i = 0; i < QR6_K; ++i) {
412
+ const int sc = scales[4 * i];
413
+
414
+ const int vil = (vl >> (4 * i)) & 0x0F0F0F0F;
415
+
416
+ const int vih = ((vh >> (4 * i)) << 4) & 0x30303030;
417
+
418
+ const int vi = dpct::vectorized_binary<sycl::char4>((vil | vih), 0x20202020,
419
+ dpct::sub_sat()); // vi = (vil | vih) - 32
420
+
421
+ sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product
422
+ }
423
+
424
+ return d * sumf;
425
+ }
426
+
427
+ __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
428
+ const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr, const sycl::half2 * q8_1_ds,
429
+ const int iqs) {
430
+ const int ib = ibx_offset.first / (QK_K / 2);
431
+
432
+ const uint8_t * base = static_cast<const uint8_t *>(vbq);
433
+ const uint8_t * ql = base + ibx_offset.first;
434
+ const uint8_t * qh = base + ibx_offset.second;
435
+ const int8_t * scales = reinterpret_cast<const int8_t *>(base + d_offset.first);
436
+ const ggml_half * d = (const ggml_half *) (base + d_offset.second) + ib;
437
+
438
+ const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 4);
439
+ const int scale_offset = (QI6_K / 4) * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 8);
440
+ const int vh_shift = 2 * ((iqs % (QI6_K / 2)) / (QI6_K / 4));
441
+
442
+ const int vl = get_int_from_uint8(ql, iqs);
443
+ const int vh = get_int_from_uint8(qh, (QI6_K / 4) * (iqs / (QI6_K / 2)) + iqs % (QI6_K / 4)) >> vh_shift;
444
+
445
+ const int8_t * scs = scales + scale_offset;
446
+
447
+ int u[QR6_K];
448
+ float d8[QR6_K];
449
+
450
+ #pragma unroll
451
+ for (int i = 0; i < QR6_K; ++i) {
452
+ u[i] = get_int_from_int8_aligned(q8_1_quant_ptr + (bq8_offset + 2 * i) * QK8_1, iqs % QI8_1);
453
+ const sycl::half2 ds_values = *(q8_1_ds + bq8_offset + 2 * i);
454
+ d8[i] = ds_values[0];
455
+ }
456
+ return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scs, *d, d8);
457
+ }
458
+ };
367
459
  #define VDR_Q4_0_Q8_1_MMVQ 2
368
460
  #define VDR_Q4_0_Q8_1_MMQ 4
369
461
 
@@ -49,15 +49,7 @@ if (Vulkan_FOUND)
49
49
  ../../include/ggml-vulkan.h
50
50
  )
51
51
 
52
- set(VULKAN_SHADER_GEN_CMAKE_ARGS
53
- -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}
54
- -DCMAKE_RUNTIME_OUTPUT_DIRECTORY=${CMAKE_RUNTIME_OUTPUT_DIRECTORY}
55
- )
56
-
57
- set(VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS "")
58
- if (CMAKE_BUILD_TYPE AND CMAKE_BUILD_TYPE MATCHES "Debug|Release|MinSizeRel|RelWithDebInfo")
59
- list(APPEND VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS --config=${CMAKE_BUILD_TYPE})
60
- endif()
52
+ set(VULKAN_SHADER_GEN_CMAKE_ARGS "")
61
53
 
62
54
  # Test all shader extensions
63
55
  test_shader_extension_support(
@@ -136,42 +128,45 @@ if (Vulkan_FOUND)
136
128
  set(HOST_CMAKE_TOOLCHAIN_FILE "")
137
129
  endif()
138
130
 
139
- # Always use ExternalProject_Add approach
140
131
  include(ExternalProject)
141
132
 
142
- # Add toolchain file if cross-compiling
143
133
  if (CMAKE_CROSSCOMPILING)
144
134
  list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE})
145
135
  message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}")
146
136
  endif()
147
137
 
148
- # Native build through ExternalProject_Add
149
138
  ExternalProject_Add(
150
139
  vulkan-shaders-gen
151
140
  SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
152
- CMAKE_ARGS ${VULKAN_SHADER_GEN_CMAKE_ARGS}
153
- BUILD_COMMAND ${CMAKE_COMMAND} --build . ${VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS}
154
- INSTALL_COMMAND ${CMAKE_COMMAND} --install .
155
- INSTALL_DIR ${CMAKE_BINARY_DIR}
141
+ CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}/$<CONFIG>
142
+ -DCMAKE_INSTALL_BINDIR=.
143
+ -DCMAKE_BUILD_TYPE=$<CONFIG>
144
+ ${VULKAN_SHADER_GEN_CMAKE_ARGS}
145
+
146
+ BUILD_COMMAND ${CMAKE_COMMAND} --build . --config $<CONFIG>
147
+
148
+ # NOTE: When DESTDIR is set using Makefile generators and
149
+ # "make install" triggers the build step, vulkan-shaders-gen
150
+ # would be installed into the DESTDIR prefix, so it is unset
151
+ # to ensure that does not happen.
152
+
153
+ INSTALL_COMMAND ${CMAKE_COMMAND} -E env --unset=DESTDIR
154
+ ${CMAKE_COMMAND} --install . --config $<CONFIG>
156
155
  )
157
- ExternalProject_Add_StepTargets(vulkan-shaders-gen build install)
158
156
 
159
157
  set (_ggml_vk_host_suffix $<IF:$<STREQUAL:${CMAKE_HOST_SYSTEM_NAME},Windows>,.exe,>)
160
- set (_ggml_vk_genshaders_cmd ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/vulkan-shaders-gen${_ggml_vk_host_suffix})
161
- set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp)
162
- set (_ggml_vk_source ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp)
163
- set (_ggml_vk_input_dir ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders)
164
- set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv)
158
+ set (_ggml_vk_genshaders_dir "${CMAKE_BINARY_DIR}/$<CONFIG>")
159
+ set (_ggml_vk_genshaders_cmd "${_ggml_vk_genshaders_dir}/vulkan-shaders-gen${_ggml_vk_host_suffix}")
160
+ set (_ggml_vk_header "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp")
161
+ set (_ggml_vk_source "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp")
162
+ set (_ggml_vk_input_dir "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders")
163
+ set (_ggml_vk_output_dir "${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv")
165
164
 
166
- file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp")
167
- set (_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen)
168
-
169
- # Add build and install dependencies for all builds
170
- set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install)
165
+ file(GLOB _ggml_vk_shader_files CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.comp")
171
166
 
172
167
  add_custom_command(
173
168
  OUTPUT ${_ggml_vk_header}
174
- ${_ggml_vk_source}
169
+ ${_ggml_vk_source}
175
170
 
176
171
  COMMAND ${_ggml_vk_genshaders_cmd}
177
172
  --glslc ${Vulkan_GLSLC_EXECUTABLE}
@@ -181,7 +176,9 @@ if (Vulkan_FOUND)
181
176
  --target-cpp ${_ggml_vk_source}
182
177
  --no-clean
183
178
 
184
- DEPENDS ${_ggml_vk_shader_deps}
179
+ DEPENDS ${_ggml_vk_shader_files}
180
+ vulkan-shaders-gen
181
+
185
182
  COMMENT "Generate vulkan shaders"
186
183
  )
187
184