@fugood/llama.node 0.3.2 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (286)
  1. package/CMakeLists.txt +7 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +18 -1
  17. package/package.json +1 -1
  18. package/src/DetokenizeWorker.cpp +1 -1
  19. package/src/EmbeddingWorker.cpp +17 -7
  20. package/src/EmbeddingWorker.h +2 -1
  21. package/src/LlamaCompletionWorker.cpp +8 -8
  22. package/src/LlamaCompletionWorker.h +2 -2
  23. package/src/LlamaContext.cpp +89 -27
  24. package/src/LlamaContext.h +2 -0
  25. package/src/TokenizeWorker.cpp +1 -1
  26. package/src/common.hpp +4 -4
  27. package/src/llama.cpp/.github/workflows/build.yml +240 -168
  28. package/src/llama.cpp/.github/workflows/docker.yml +8 -8
  29. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  30. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  31. package/src/llama.cpp/CMakeLists.txt +14 -6
  32. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/cmake/common.cmake +33 -0
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  36. package/src/llama.cpp/common/CMakeLists.txt +6 -4
  37. package/src/llama.cpp/common/arg.cpp +986 -770
  38. package/src/llama.cpp/common/arg.h +22 -22
  39. package/src/llama.cpp/common/common.cpp +212 -351
  40. package/src/llama.cpp/common/common.h +204 -117
  41. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  42. package/src/llama.cpp/common/log.cpp +50 -50
  43. package/src/llama.cpp/common/log.h +18 -18
  44. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  45. package/src/llama.cpp/common/ngram-cache.h +19 -19
  46. package/src/llama.cpp/common/sampling.cpp +163 -121
  47. package/src/llama.cpp/common/sampling.h +41 -20
  48. package/src/llama.cpp/common/speculative.cpp +274 -0
  49. package/src/llama.cpp/common/speculative.h +28 -0
  50. package/src/llama.cpp/docs/build.md +134 -161
  51. package/src/llama.cpp/examples/CMakeLists.txt +33 -14
  52. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/batched/batched.cpp +19 -18
  54. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  55. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  56. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  57. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  58. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  60. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  61. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  63. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  64. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  65. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  66. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  67. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  68. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  69. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  70. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  71. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  72. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  73. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  75. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  76. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  77. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  78. package/src/llama.cpp/examples/imatrix/imatrix.cpp +31 -13
  79. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  80. package/src/llama.cpp/examples/infill/infill.cpp +41 -87
  81. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +439 -459
  83. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +2 -0
  84. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  85. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  86. package/src/llama.cpp/examples/llava/clip.cpp +263 -66
  87. package/src/llama.cpp/examples/llava/clip.h +8 -2
  88. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  89. package/src/llama.cpp/examples/llava/llava.cpp +83 -22
  90. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  91. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  92. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  94. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  95. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  96. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  97. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +16 -15
  98. package/src/llama.cpp/examples/lookup/lookup.cpp +30 -30
  99. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  100. package/src/llama.cpp/examples/main/main.cpp +73 -114
  101. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  102. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  104. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  105. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  106. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  108. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  109. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  110. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  111. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  112. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  113. package/src/llama.cpp/examples/retrieval/retrieval.cpp +16 -16
  114. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  115. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  116. package/src/llama.cpp/examples/run/run.cpp +911 -0
  117. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  118. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +38 -21
  119. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -16
  120. package/src/llama.cpp/examples/server/server.cpp +2073 -1339
  121. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  122. package/src/llama.cpp/examples/server/utils.hpp +354 -277
  123. package/src/llama.cpp/examples/simple/CMakeLists.txt +2 -2
  124. package/src/llama.cpp/examples/simple/simple.cpp +130 -94
  125. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  126. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +200 -0
  127. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  128. package/src/llama.cpp/examples/speculative/speculative.cpp +68 -64
  129. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  130. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  131. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  132. package/src/llama.cpp/examples/tokenize/tokenize.cpp +3 -3
  133. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  134. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  135. package/src/llama.cpp/ggml/CMakeLists.txt +54 -36
  136. package/src/llama.cpp/ggml/include/ggml-backend.h +63 -34
  137. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  138. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  139. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  140. package/src/llama.cpp/ggml/include/ggml-cpu.h +135 -0
  141. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  142. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  143. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  144. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  145. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  146. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  147. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  148. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  149. package/src/llama.cpp/ggml/include/ggml.h +159 -417
  150. package/src/llama.cpp/ggml/src/CMakeLists.txt +121 -1155
  151. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -28
  152. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +57 -36
  153. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +552 -0
  154. package/src/llama.cpp/ggml/src/ggml-backend.cpp +306 -867
  155. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  156. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +216 -65
  157. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +76 -0
  158. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  159. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  160. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +343 -177
  161. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  162. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  163. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  164. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  165. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  166. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  167. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  168. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  169. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  170. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +336 -0
  171. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  172. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  173. package/src/llama.cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
  174. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  175. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  176. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  177. package/src/llama.cpp/ggml/src/{ggml-aarch64.c → ggml-cpu/ggml-cpu-aarch64.cpp} +1299 -246
  178. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  179. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  180. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  181. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +14 -242
  182. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  183. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  184. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  185. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  186. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  187. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +628 -0
  188. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +666 -0
  189. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +152 -0
  190. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +104 -0
  192. package/src/llama.cpp/ggml/src/ggml-impl.h +393 -22
  193. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  194. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +360 -127
  195. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +105 -0
  196. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  197. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +107 -0
  198. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  199. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  200. package/src/llama.cpp/ggml/src/ggml-opt.cpp +854 -0
  201. package/src/llama.cpp/ggml/src/ggml-quants.c +188 -10702
  202. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  203. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  204. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +478 -300
  205. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +84 -0
  206. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  207. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +36 -5
  208. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +259 -0
  209. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
  210. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  211. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  212. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +34 -35
  213. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  214. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  215. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  216. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3638 -4151
  217. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  218. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  219. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -87
  220. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +7 -6
  221. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  222. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  223. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  224. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  225. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  226. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  227. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  228. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  229. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  230. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  231. package/src/llama.cpp/ggml/src/ggml-threading.h +14 -0
  232. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +92 -0
  233. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2138 -887
  234. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +3 -1
  235. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  236. package/src/llama.cpp/ggml/src/ggml.c +4427 -20125
  237. package/src/llama.cpp/include/llama-cpp.h +25 -0
  238. package/src/llama.cpp/include/llama.h +93 -52
  239. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  240. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  241. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  242. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  243. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  244. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  245. package/src/llama.cpp/src/CMakeLists.txt +4 -8
  246. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  247. package/src/llama.cpp/src/llama-grammar.h +2 -5
  248. package/src/llama.cpp/src/llama-sampling.cpp +779 -194
  249. package/src/llama.cpp/src/llama-sampling.h +21 -2
  250. package/src/llama.cpp/src/llama-vocab.cpp +55 -10
  251. package/src/llama.cpp/src/llama-vocab.h +35 -11
  252. package/src/llama.cpp/src/llama.cpp +4317 -2979
  253. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  254. package/src/llama.cpp/src/unicode.cpp +62 -51
  255. package/src/llama.cpp/src/unicode.h +9 -10
  256. package/src/llama.cpp/tests/CMakeLists.txt +48 -38
  257. package/src/llama.cpp/tests/test-arg-parser.cpp +15 -15
  258. package/src/llama.cpp/tests/test-backend-ops.cpp +324 -80
  259. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  260. package/src/llama.cpp/tests/test-chat-template.cpp +59 -9
  261. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  262. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  263. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  264. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  265. package/src/llama.cpp/tests/test-log.cpp +2 -2
  266. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  267. package/src/llama.cpp/tests/test-quantize-fns.cpp +24 -21
  268. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  269. package/src/llama.cpp/tests/test-rope.cpp +62 -20
  270. package/src/llama.cpp/tests/test-sampling.cpp +163 -138
  271. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  272. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  273. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  274. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  275. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  276. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  277. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  278. package/src/llama.cpp/common/train.cpp +0 -1515
  279. package/src/llama.cpp/common/train.h +0 -233
  280. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  281. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  282. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -39
  283. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +0 -600
  284. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  285. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  286. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
@@ -8,7 +8,6 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep
8
8
 
9
9
  const int nthreads = item_ct1.get_local_range(2);
10
10
  const int nwarps = nthreads / WARP_SIZE;
11
- assert(nwarps % WARP_SIZE == 0);
12
11
  sycl::float2 mean_var = sycl::float2(0.f, 0.f);
13
12
 
14
13
  for (int col = tid; col < ncols; col += block_size) {
@@ -32,7 +31,7 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep
32
31
  */
33
32
  item_ct1.barrier(sycl::access::fence_space::local_space);
34
33
  mean_var = 0.f;
35
- int nreduce = nwarps / WARP_SIZE;
34
+ size_t nreduce = nwarps / WARP_SIZE;
36
35
  for (size_t i = 0; i < nreduce; i += 1)
37
36
  {
38
37
  mean_var += s_sum[lane_id + i * WARP_SIZE];
@@ -55,9 +54,8 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
55
54
  int end = start + group_size;
56
55
  const int nthreads = item_ct1.get_local_range(2);
57
56
  const int nwarps = nthreads / WARP_SIZE;
58
- assert(nwarps % WARP_SIZE == 0);
59
57
  start += item_ct1.get_local_id(2);
60
- int nreduce = nwarps / WARP_SIZE;
58
+ size_t nreduce = nwarps / WARP_SIZE;
61
59
 
62
60
  if (end >= ne_elements) {
63
61
  end = ne_elements;
@@ -144,7 +142,6 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
144
142
  const int tid = item_ct1.get_local_id(2);
145
143
  const int nthreads = item_ct1.get_local_range(2);
146
144
  const int nwarps = nthreads / WARP_SIZE;
147
- assert(nwarps % WARP_SIZE == 0);
148
145
  float tmp = 0.0f; // partial sum for thread in warp
149
146
 
150
147
  for (int col = tid; col < ncols; col += block_size) {
@@ -166,7 +163,7 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
166
163
  converged control flow. You may need to adjust the code.
167
164
  */
168
165
  item_ct1.barrier(sycl::access::fence_space::local_space);
169
- int nreduce = nwarps / WARP_SIZE;
166
+ size_t nreduce = nwarps / WARP_SIZE;
170
167
  tmp = 0.f;
171
168
  for (size_t i = 0; i < nreduce; i += 1)
172
169
  {
@@ -202,6 +199,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
202
199
  }
203
200
  else {
204
201
  const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
202
+ assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
205
203
  const sycl::range<3> block_dims(1, 1, work_group_size);
206
204
  /*
207
205
  DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
@@ -244,6 +242,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
244
242
  }
245
243
  else {
246
244
  const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
245
+ assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
247
246
  const sycl::range<3> block_dims(1, 1, work_group_size);
248
247
  /*
249
248
  DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
@@ -290,6 +289,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
290
289
  }
291
290
  else {
292
291
  const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
292
+ assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
293
293
  const sycl::range<3> block_dims(1, 1, work_group_size);
294
294
  /*
295
295
  DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
@@ -352,6 +352,7 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
352
352
  (void)src1;
353
353
  (void)dst;
354
354
  (void)src1_dd;
355
+ GGML_UNUSED(ctx);
355
356
  }
356
357
 
357
358
  void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
@@ -0,0 +1,56 @@
1
+ #include <sycl/sycl.hpp>
2
+ #include <oneapi/mkl.hpp>
3
+ #include "outprod.hpp"
4
+
5
+
6
+ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
7
+ const ggml_tensor* src1, ggml_tensor* dst) {
8
+
9
+
10
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
11
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
12
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
13
+ GGML_ASSERT(ggml_is_contiguous(src0));
14
+ GGML_ASSERT(ggml_is_contiguous(dst));
15
+
16
+ GGML_TENSOR_BINARY_OP_LOCALS
17
+
18
+ // Get SYCL queue
19
+ dpct::queue_ptr stream = ctx.stream();
20
+
21
+ // Dimension checks
22
+ GGML_ASSERT(ne01 == ne11); // Inner dimensions must match
23
+ GGML_ASSERT(ne0 == ne00); // Output rows match src0 rows
24
+ GGML_ASSERT(ne1 == ne10); // Output cols match src1 cols
25
+
26
+ // Get data pointers
27
+ const float* src0_d = (const float*)src0->data;
28
+ const float* src1_d = (const float*)src1->data;
29
+ float* dst_d = (float*)dst->data;
30
+
31
+ // GEMM parameters
32
+ const float alpha = 1.0f;
33
+ const float beta = 0.0f;
34
+
35
+ // Handle transposition of src1
36
+ const bool src1_T = ggml_is_transposed(src1);
37
+ const oneapi::mkl::transpose src1_op =
38
+ src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
39
+ const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
40
+
41
+ try {
42
+ // Perform matrix multiplication using oneMKL GEMM
43
+ #ifdef GGML_SYCL_NVIDIA
44
+ oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream },
45
+ oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d,
46
+ ne00, src1_d, ldb, beta, dst_d, ne0);
47
+ #else
48
+ oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,
49
+ src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
50
+ #endif
51
+ }
52
+ catch (sycl::exception const& exc) {
53
+ std::cerr << exc.what() << std::endl;
54
+ GGML_ASSERT(false);
55
+ }
56
+ }
@@ -0,0 +1,11 @@
1
+ #ifndef GGML_SYCL_OUTPROD_HPP
2
+ #define GGML_SYCL_OUTPROD_HPP
3
+
4
+ #include "common.hpp"
5
+
6
+ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
7
+ const ggml_tensor* src1, ggml_tensor* dst);
8
+
9
+
10
+ #endif // GGML_SYCL_OUTPROD_HPP
11
+
@@ -25,6 +25,11 @@
25
25
  #define SYCL_RELU_BLOCK_SIZE 256
26
26
  #define SYCL_HARDSIGMOID_BLOCK_SIZE 256
27
27
  #define SYCL_HARDSWISH_BLOCK_SIZE 256
28
+ #define SYCL_EXP_BLOCK_SIZE 256
29
+ #define SYCL_NEG_BLOCK_SIZE 256
30
+ #define SYCL_SIGMOID_BLOCK_SIZE 256
31
+ #define SYCL_SQRT_BLOCK_SIZE 256
32
+ #define SYCL_SIN_BLOCK_SIZE 256
28
33
  #define SYCL_SQR_BLOCK_SIZE 256
29
34
  #define SYCL_CPY_BLOCK_SIZE 32
30
35
  #define SYCL_SCALE_BLOCK_SIZE 256
@@ -41,6 +46,7 @@
41
46
  #define SYCL_ACC_BLOCK_SIZE 256
42
47
  #define SYCL_IM2COL_BLOCK_SIZE 256
43
48
  #define SYCL_POOL2D_BLOCK_SIZE 256
49
+ #define SYCL_ARGMAX_BLOCK_SIZE 256
44
50
  #define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
45
51
  #define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
46
52
 
@@ -269,7 +269,8 @@ void ggml_sycl_op_rope(
269
269
  }
270
270
  }
271
271
 
272
- (void) src1;
273
- (void) dst;
274
- (void) src1_dd;
272
+ GGML_UNUSED(src1);
273
+ GGML_UNUSED(dst);
274
+ GGML_UNUSED(src1_dd);
275
+ GGML_UNUSED(ctx);
275
276
  }
@@ -16,7 +16,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
16
16
  const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
17
17
  const int nthreads = block_size;
18
18
  const int nwarps = nthreads / WARP_SIZE;
19
- int nreduce = nwarps / WARP_SIZE;
19
+ size_t nreduce = nwarps / WARP_SIZE;
20
20
  float slope = 1.0f;
21
21
 
22
22
  // ALiBi
@@ -53,8 +53,9 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
53
53
  if (block_size > WARP_SIZE) {
54
54
  if (warp_id == 0) {
55
55
  buf[lane_id] = -INFINITY;
56
- for (size_t i = 1; i < nreduce; i += 1)
56
+ for (size_t i = 1; i < nreduce; i += 1) {
57
57
  buf[lane_id + i * WARP_SIZE] = -INFINITY;
58
+ }
58
59
  }
59
60
  item_ct1.barrier(sycl::access::fence_space::local_space);
60
61
 
@@ -63,8 +64,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
63
64
  }
64
65
  item_ct1.barrier(sycl::access::fence_space::local_space);
65
66
  max_val = buf[lane_id];
66
- for (size_t i = 1; i < nreduce; i += 1)
67
- {
67
+ for (size_t i = 1; i < nreduce; i += 1) {
68
68
  max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
69
69
  }
70
70
  max_val = warp_reduce_max(max_val, item_ct1);
@@ -89,8 +89,9 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
89
89
  item_ct1.barrier(sycl::access::fence_space::local_space);
90
90
  if (warp_id == 0) {
91
91
  buf[lane_id] = 0.f;
92
- for (size_t i = 1; i < nreduce; i += 1)
92
+ for (size_t i = 1; i < nreduce; i += 1) {
93
93
  buf[lane_id + i * WARP_SIZE] = 0.f;
94
+ }
94
95
  }
95
96
  item_ct1.barrier(sycl::access::fence_space::local_space);
96
97
 
@@ -100,8 +101,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
100
101
  item_ct1.barrier(sycl::access::fence_space::local_space);
101
102
 
102
103
  tmp = buf[lane_id];
103
- for (size_t i = 1; i < nreduce; i += 1)
104
- {
104
+ for (size_t i = 1; i < nreduce; i += 1) {
105
105
  tmp += buf[lane_id + i * WARP_SIZE];
106
106
  }
107
107
  tmp = warp_reduce_sum(tmp, item_ct1);
@@ -68,4 +68,5 @@ void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml
68
68
  const int max_period = dst->op_params[1];
69
69
 
70
70
  timestep_embedding_f32_sycl(src0_d, dst_d, src0->ne[0], dst->nb[1], dim, max_period, stream);
71
+ GGML_UNUSED(src1);
71
72
  }
@@ -968,8 +968,8 @@ vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
968
968
  grid1[0] ^ signs[0], signs[0], std::minus<>());
969
969
  const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
970
970
  grid2[0] ^ signs[1], signs[1], std::minus<>());
971
- sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
972
- sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
971
+ sumi = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi);
972
+ sumi = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi);
973
973
  q8 += 8;
974
974
  aux32 >>= 7;
975
975
  }
@@ -1009,8 +1009,8 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
1009
1009
  grid1[0] ^ signs0, signs0, std::minus<>());
1010
1010
  const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
1011
1011
  grid2[0] ^ signs1, signs1, std::minus<>());
1012
- sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
1013
- sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
1012
+ sumi = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi);
1013
+ sumi = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi);
1014
1014
  q8 += 8;
1015
1015
  }
1016
1016
  const float d =
@@ -0,0 +1,141 @@
1
+ #include <sycl/sycl.hpp>
2
+ #include "wkv6.hpp"
3
+
4
+ constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
5
+
6
+ // Helper function for the main kernel
7
+ static void rwkv_wkv_f32_kernel(
8
+ const int B, const int T, const int C, const int H,
9
+ const float* k, const float* v, const float* r,
10
+ const float* tf, const float* td, const float* s,
11
+ float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
12
+
13
+ const int tid = item_ct1.get_local_id(2);
14
+ const int bid = item_ct1.get_group(2);
15
+
16
+ const int head_size = WKV_BLOCK_SIZE;
17
+ const int batch_i = bid / H;
18
+ const int head_i = bid % H;
19
+ const int state_size = C * head_size;
20
+ const int n_seq_tokens = T / B;
21
+
22
+ // Set up shared memory pointers
23
+ float* _k = shared_mem;
24
+ float* _r = _k + head_size;
25
+ float* _tf = _r + head_size;
26
+ float* _td = _tf + head_size;
27
+
28
+ // Local state array
29
+ float state[WKV_BLOCK_SIZE];
30
+
31
+ // Load initial state
32
+ #pragma unroll
33
+ for (int i = 0; i < head_size; i++) {
34
+ state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
35
+ }
36
+
37
+ // Sync threads before shared memory operations
38
+ item_ct1.barrier(sycl::access::fence_space::local_space);
39
+
40
+ // Load time-mixing parameters
41
+ _tf[tid] = tf[head_i * head_size + tid];
42
+ item_ct1.barrier(sycl::access::fence_space::local_space);
43
+
44
+ // Main sequence processing loop
45
+ for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
46
+ t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
47
+ t += C) {
48
+
49
+ item_ct1.barrier(sycl::access::fence_space::local_space);
50
+
51
+ // Load current timestep data to shared memory
52
+ _k[tid] = k[t];
53
+ _r[tid] = r[t];
54
+ _td[tid] = td[t];
55
+
56
+ item_ct1.barrier(sycl::access::fence_space::local_space);
57
+
58
+ const float _v = v[t];
59
+ float y = 0;
60
+
61
+ // Process in chunks of 4 for better vectorization
62
+ sycl::float4 k4, r4, tf4, td4, s4;
63
+ #pragma unroll
64
+ for (int j = 0; j < head_size; j += 4) {
65
+ // Load data in vec4 chunks
66
+ k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
67
+ r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
68
+ tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
69
+ td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
70
+ s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
71
+
72
+ // Compute key-value product
73
+ sycl::float4 kv4 = k4 * _v;
74
+
75
+ // Accumulate weighted sum
76
+ y += sycl::dot(r4, tf4 * kv4 + s4);
77
+
78
+ // Update state
79
+ s4 = s4 * td4 + kv4;
80
+
81
+ // Store updated state
82
+ state[j] = s4.x();
83
+ state[j+1] = s4.y();
84
+ state[j+2] = s4.z();
85
+ state[j+3] = s4.w();
86
+ }
87
+
88
+ dst[t] = y;
89
+ }
90
+
91
+ // Save final state
92
+ #pragma unroll
93
+ for (int i = 0; i < head_size; i++) {
94
+ dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
95
+ }
96
+ }
97
+
98
+ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
99
+ const ggml_tensor* src1, ggml_tensor* dst) {
100
+
101
+ const float* k_d = (const float*)dst->src[0]->data;
102
+ const float* v_d = (const float*)dst->src[1]->data;
103
+ const float* r_d = (const float*)dst->src[2]->data;
104
+ const float* tf_d = (const float*)dst->src[3]->data;
105
+ const float* td_d = (const float*)dst->src[4]->data;
106
+ const float* s_d = (const float*)dst->src[5]->data;
107
+ float* dst_d = (float*)dst->data;
108
+
109
+ const int64_t B = dst->src[5]->ne[1];
110
+ const int64_t T = dst->src[0]->ne[3];
111
+ const int64_t C = dst->ne[0];
112
+ const int64_t H = dst->src[0]->ne[2];
113
+
114
+ GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
115
+ GGML_ASSERT(C % H == 0);
116
+ GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
117
+
118
+ dpct::queue_ptr stream = ctx.stream();
119
+
120
+ // Calculate execution configuration
121
+ const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td
122
+ sycl::range<3> block_dims(1, 1, C / H);
123
+ sycl::range<3> grid_dims(1, 1, B * H);
124
+
125
+ // Submit kernel
126
+ stream->submit([&](sycl::handler& cgh) {
127
+ sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
128
+
129
+ cgh.parallel_for(
130
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
131
+ [=](sycl::nd_item<3> item_ct1) {
132
+ rwkv_wkv_f32_kernel(
133
+ B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
134
+ item_ct1, shared_mem_acc.get_pointer()
135
+ );
136
+ });
137
+ });
138
+
139
+ GGML_UNUSED(src0);
140
+ GGML_UNUSED(src1);
141
+ }
@@ -0,0 +1,10 @@
1
+ #ifndef GGML_SYCL_WKV6_HPP
2
+ #define GGML_SYCL_WKV6_HPP
3
+
4
+ #include "common.hpp"
5
+
6
+ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
7
+ const ggml_tensor *src1, ggml_tensor * dst);
8
+
9
+
10
+ #endif // GGML_SYCL_WKV6_HPP
@@ -0,0 +1,12 @@
1
+ #include "ggml-threading.h"
2
+ #include <mutex>
3
+
4
// Single process-wide mutex backing the ggml critical-section API below.
std::mutex ggml_critical_section_mutex;

// Enters the critical section: blocks until the global mutex is acquired.
// Must be paired with a later ggml_critical_section_end() on the same thread.
void ggml_critical_section_start(void) {
    ggml_critical_section_mutex.lock();
}

// Leaves the critical section by releasing the global mutex.
void ggml_critical_section_end(void) {
    ggml_critical_section_mutex.unlock();
}
@@ -0,0 +1,14 @@
1
+ #pragma once
2
+
3
+ #include "ggml.h"
4
+
5
+ #ifdef __cplusplus
6
+ extern "C" {
7
+ #endif
8
+
9
+ GGML_API void ggml_critical_section_start(void);
10
+ GGML_API void ggml_critical_section_end(void);
11
+
12
+ #ifdef __cplusplus
13
+ }
14
+ #endif
@@ -0,0 +1,92 @@
1
# Build script for the ggml Vulkan backend.
#
# Steps:
#   1. Locate Vulkan together with the glslc shader compiler.
#   2. Declare the ggml-vulkan backend library.
#   3. Probe glslc for GL_NV_cooperative_matrix2 support.
#   4. Forward the GGML_VULKAN_* options as compile definitions.
#   5. Generate the shader header/source with vulkan-shaders-gen and add them to the library.
find_package(Vulkan COMPONENTS glslc REQUIRED)

if (Vulkan_FOUND)
    message(STATUS "Vulkan found")

    ggml_add_backend_library(ggml-vulkan
                             ggml-vulkan.cpp
                             ../../include/ggml-vulkan.h
                            )

    # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported.
    # If it's not, there will be an error to stderr.
    # If it's supported, set a define to indicate that we should compile those shaders
    execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp"
                    OUTPUT_VARIABLE glslc_output
                    ERROR_VARIABLE glslc_error)

    # The expansion must be quoted: unquoted, an empty stderr would vanish from
    # the argument list (leaving `if(MATCHES ...)`) and a ';' in the message
    # would split it into several arguments.
    if ("${glslc_error}" MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*")
        message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc")
    else()
        message(STATUS "GL_NV_cooperative_matrix2 supported by glslc")
        add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
    endif()

    target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
    target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})

    # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build
    # Possibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector
    if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
        add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0)
    endif()

    # NOTE: the definitions in this file are directory-scoped on purpose of
    # behavior preservation: add_compile_definitions() also applies to targets
    # of sub-directories added afterwards (vulkan-shaders below). Check that
    # sub-directory before converting any of them to target_compile_definitions().
    if (GGML_VULKAN_CHECK_RESULTS)
        add_compile_definitions(GGML_VULKAN_CHECK_RESULTS)
    endif()

    if (GGML_VULKAN_DEBUG)
        add_compile_definitions(GGML_VULKAN_DEBUG)
    endif()

    if (GGML_VULKAN_MEMORY_DEBUG)
        add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG)
    endif()

    if (GGML_VULKAN_SHADER_DEBUG_INFO)
        add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO)
    endif()

    if (GGML_VULKAN_PERF)
        add_compile_definitions(GGML_VULKAN_PERF)
    endif()

    if (GGML_VULKAN_VALIDATE)
        add_compile_definitions(GGML_VULKAN_VALIDATE)
    endif()

    if (GGML_VULKAN_RUN_TESTS)
        add_compile_definitions(GGML_VULKAN_RUN_TESTS)
    endif()

    add_subdirectory(vulkan-shaders)

    # Inputs and outputs of the shader generation step.
    set (_ggml_vk_genshaders_cmd vulkan-shaders-gen)
    set (_ggml_vk_header         ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp)
    set (_ggml_vk_source         ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp)
    set (_ggml_vk_input_dir      ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders)
    set (_ggml_vk_output_dir     ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv)

    # Every *.comp shader is a dependency of the generated files.
    file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp")

    add_custom_command(
        OUTPUT ${_ggml_vk_header}
               ${_ggml_vk_source}

        COMMAND ${_ggml_vk_genshaders_cmd}
            --glslc      ${Vulkan_GLSLC_EXECUTABLE}
            --input-dir  ${_ggml_vk_input_dir}
            --output-dir ${_ggml_vk_output_dir}
            --target-hpp ${_ggml_vk_header}
            --target-cpp ${_ggml_vk_source}
            --no-clean

        DEPENDS ${_ggml_vk_shader_deps} ${_ggml_vk_genshaders_cmd}
        COMMENT "Generate vulkan shaders"
        # Pass arguments to the tool exactly as written, independent of the
        # platform shell (matters for paths containing spaces).
        VERBATIM
    )

    target_sources(ggml-vulkan PRIVATE ${_ggml_vk_source} ${_ggml_vk_header})

else()
    message(WARNING "Vulkan not found")
endif()