@fugood/llama.node 0.3.2 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (286)
  1. package/CMakeLists.txt +7 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +18 -1
  17. package/package.json +1 -1
  18. package/src/DetokenizeWorker.cpp +1 -1
  19. package/src/EmbeddingWorker.cpp +17 -7
  20. package/src/EmbeddingWorker.h +2 -1
  21. package/src/LlamaCompletionWorker.cpp +8 -8
  22. package/src/LlamaCompletionWorker.h +2 -2
  23. package/src/LlamaContext.cpp +89 -27
  24. package/src/LlamaContext.h +2 -0
  25. package/src/TokenizeWorker.cpp +1 -1
  26. package/src/common.hpp +4 -4
  27. package/src/llama.cpp/.github/workflows/build.yml +240 -168
  28. package/src/llama.cpp/.github/workflows/docker.yml +8 -8
  29. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  30. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  31. package/src/llama.cpp/CMakeLists.txt +14 -6
  32. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/cmake/common.cmake +33 -0
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  36. package/src/llama.cpp/common/CMakeLists.txt +6 -4
  37. package/src/llama.cpp/common/arg.cpp +986 -770
  38. package/src/llama.cpp/common/arg.h +22 -22
  39. package/src/llama.cpp/common/common.cpp +212 -351
  40. package/src/llama.cpp/common/common.h +204 -117
  41. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  42. package/src/llama.cpp/common/log.cpp +50 -50
  43. package/src/llama.cpp/common/log.h +18 -18
  44. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  45. package/src/llama.cpp/common/ngram-cache.h +19 -19
  46. package/src/llama.cpp/common/sampling.cpp +163 -121
  47. package/src/llama.cpp/common/sampling.h +41 -20
  48. package/src/llama.cpp/common/speculative.cpp +274 -0
  49. package/src/llama.cpp/common/speculative.h +28 -0
  50. package/src/llama.cpp/docs/build.md +134 -161
  51. package/src/llama.cpp/examples/CMakeLists.txt +33 -14
  52. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/batched/batched.cpp +19 -18
  54. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  55. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  56. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  57. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  58. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  60. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  61. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  63. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  64. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  65. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  66. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  67. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  68. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  69. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  70. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  71. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  72. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  73. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  75. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  76. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  77. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  78. package/src/llama.cpp/examples/imatrix/imatrix.cpp +31 -13
  79. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  80. package/src/llama.cpp/examples/infill/infill.cpp +41 -87
  81. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +439 -459
  83. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +2 -0
  84. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  85. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  86. package/src/llama.cpp/examples/llava/clip.cpp +263 -66
  87. package/src/llama.cpp/examples/llava/clip.h +8 -2
  88. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  89. package/src/llama.cpp/examples/llava/llava.cpp +83 -22
  90. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  91. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  92. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  94. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  95. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  96. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  97. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +16 -15
  98. package/src/llama.cpp/examples/lookup/lookup.cpp +30 -30
  99. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  100. package/src/llama.cpp/examples/main/main.cpp +73 -114
  101. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  102. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  104. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  105. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  106. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  108. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  109. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  110. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  111. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  112. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  113. package/src/llama.cpp/examples/retrieval/retrieval.cpp +16 -16
  114. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  115. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  116. package/src/llama.cpp/examples/run/run.cpp +911 -0
  117. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  118. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +38 -21
  119. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -16
  120. package/src/llama.cpp/examples/server/server.cpp +2073 -1339
  121. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  122. package/src/llama.cpp/examples/server/utils.hpp +354 -277
  123. package/src/llama.cpp/examples/simple/CMakeLists.txt +2 -2
  124. package/src/llama.cpp/examples/simple/simple.cpp +130 -94
  125. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  126. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +200 -0
  127. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  128. package/src/llama.cpp/examples/speculative/speculative.cpp +68 -64
  129. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  130. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  131. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  132. package/src/llama.cpp/examples/tokenize/tokenize.cpp +3 -3
  133. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  134. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  135. package/src/llama.cpp/ggml/CMakeLists.txt +54 -36
  136. package/src/llama.cpp/ggml/include/ggml-backend.h +63 -34
  137. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  138. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  139. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  140. package/src/llama.cpp/ggml/include/ggml-cpu.h +135 -0
  141. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  142. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  143. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  144. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  145. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  146. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  147. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  148. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  149. package/src/llama.cpp/ggml/include/ggml.h +159 -417
  150. package/src/llama.cpp/ggml/src/CMakeLists.txt +121 -1155
  151. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -28
  152. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +57 -36
  153. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +552 -0
  154. package/src/llama.cpp/ggml/src/ggml-backend.cpp +306 -867
  155. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  156. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +216 -65
  157. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +76 -0
  158. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  159. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  160. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +343 -177
  161. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  162. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  163. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  164. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  165. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  166. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  167. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  168. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  169. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  170. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +336 -0
  171. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  172. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  173. package/src/llama.cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
  174. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  175. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  176. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  177. package/src/llama.cpp/ggml/src/{ggml-aarch64.c → ggml-cpu/ggml-cpu-aarch64.cpp} +1299 -246
  178. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  179. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  180. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  181. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +14 -242
  182. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  183. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  184. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  185. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  186. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  187. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +628 -0
  188. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +666 -0
  189. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +152 -0
  190. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +104 -0
  192. package/src/llama.cpp/ggml/src/ggml-impl.h +393 -22
  193. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  194. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +360 -127
  195. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +105 -0
  196. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  197. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +107 -0
  198. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  199. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  200. package/src/llama.cpp/ggml/src/ggml-opt.cpp +854 -0
  201. package/src/llama.cpp/ggml/src/ggml-quants.c +188 -10702
  202. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  203. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  204. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +478 -300
  205. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +84 -0
  206. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  207. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +36 -5
  208. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +259 -0
  209. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
  210. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  211. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  212. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +34 -35
  213. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  214. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  215. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  216. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3638 -4151
  217. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  218. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  219. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -87
  220. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +7 -6
  221. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  222. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  223. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  224. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  225. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  226. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  227. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  228. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  229. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  230. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  231. package/src/llama.cpp/ggml/src/ggml-threading.h +14 -0
  232. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +92 -0
  233. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2138 -887
  234. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +3 -1
  235. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  236. package/src/llama.cpp/ggml/src/ggml.c +4427 -20125
  237. package/src/llama.cpp/include/llama-cpp.h +25 -0
  238. package/src/llama.cpp/include/llama.h +93 -52
  239. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  240. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  241. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  242. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  243. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  244. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  245. package/src/llama.cpp/src/CMakeLists.txt +4 -8
  246. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  247. package/src/llama.cpp/src/llama-grammar.h +2 -5
  248. package/src/llama.cpp/src/llama-sampling.cpp +779 -194
  249. package/src/llama.cpp/src/llama-sampling.h +21 -2
  250. package/src/llama.cpp/src/llama-vocab.cpp +55 -10
  251. package/src/llama.cpp/src/llama-vocab.h +35 -11
  252. package/src/llama.cpp/src/llama.cpp +4317 -2979
  253. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  254. package/src/llama.cpp/src/unicode.cpp +62 -51
  255. package/src/llama.cpp/src/unicode.h +9 -10
  256. package/src/llama.cpp/tests/CMakeLists.txt +48 -38
  257. package/src/llama.cpp/tests/test-arg-parser.cpp +15 -15
  258. package/src/llama.cpp/tests/test-backend-ops.cpp +324 -80
  259. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  260. package/src/llama.cpp/tests/test-chat-template.cpp +59 -9
  261. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  262. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  263. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  264. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  265. package/src/llama.cpp/tests/test-log.cpp +2 -2
  266. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  267. package/src/llama.cpp/tests/test-quantize-fns.cpp +24 -21
  268. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  269. package/src/llama.cpp/tests/test-rope.cpp +62 -20
  270. package/src/llama.cpp/tests/test-sampling.cpp +163 -138
  271. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  272. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  273. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  274. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  275. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  276. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  277. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  278. package/src/llama.cpp/common/train.cpp +0 -1515
  279. package/src/llama.cpp/common/train.h +0 -233
  280. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  281. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  282. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -39
  283. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +0 -600
  284. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  285. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  286. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
@@ -20,6 +20,7 @@
  #include "shaderop_mul_mat_q8_0.h"
  #include "shaderop_mul_mat_q4_0.h"
  #include "shaderop_mul_mat_q4_1.h"
+ #include "shaderop_mul_mat_q4_k.h"
  #include "shaderop_mul_mat_q6_k.h"
  #include "shaderop_mul_mat_mat_f32.h"
  #include "shaderop_getrows_f32.h"
@@ -27,8 +28,10 @@
  #include "shaderop_getrows_q4_0.h"
  #include "shaderop_getrows_q4_1.h"
  #include "shaderop_getrows_q6_k.h"
- #include "shaderop_rope_f16.h"
- #include "shaderop_rope_f32.h"
+ #include "shaderop_rope_norm_f16.h"
+ #include "shaderop_rope_norm_f32.h"
+ #include "shaderop_rope_neox_f16.h"
+ #include "shaderop_rope_neox_f32.h"
  #include "shaderop_cpy_f16_f16.h"
  #include "shaderop_cpy_f16_f32.h"
  #include "shaderop_cpy_f32_f16.h"
@@ -42,6 +45,7 @@
  #include <cstring>
  #include <iostream>
  #include <memory>
+ #include <mutex>
  #include <stdexcept>
  #include <string>
  #include <unordered_map>
@@ -273,18 +277,9 @@ static std::vector<ggml_vk_device> ggml_vk_available_devices_internal(size_t mem
  return results;
  }

- // public API returns a C-style array
- ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count) {
- auto devices = ggml_vk_available_devices_internal(memoryRequired);
- *count = devices.size();
- if (devices.empty()) {
- return nullptr;
- }
-
- size_t nbytes = sizeof (ggml_vk_device) * (devices.size());
- auto * arr = static_cast<ggml_vk_device *>(malloc(nbytes));
- memcpy(arr, devices.data(), nbytes);
- return arr;
+ static std::vector<ggml_vk_device>& ggml_vk_available_devices() {
+ static std::vector<ggml_vk_device> devices = ggml_vk_available_devices_internal(0);
+ return devices;
  }

  static void ggml_vk_filterByVendor(std::vector<ggml_vk_device>& devices, const std::string& targetVendor) {
@@ -341,7 +336,7 @@ ggml_vk_device ggml_vk_current_device() {
  if (!komputeManager()->hasDevice())
  return ggml_vk_device();

- auto devices = ggml_vk_available_devices_internal(0);
+ auto devices = ggml_vk_available_devices();
  ggml_vk_filterByName(devices, komputeManager()->physicalDevice()->getProperties().deviceName.data());
  GGML_ASSERT(!devices.empty());
  return devices.front();
@@ -352,7 +347,7 @@ void ggml_vk_allocate_descriptor_pool(struct ggml_kompute_context * ctx, size_t
  std::vector<vk::DescriptorPoolSize> descriptorPoolSizes = {
  vk::DescriptorPoolSize(
  vk::DescriptorType::eStorageBuffer,
- 3 * size // Descriptor count is number of possible tensors to pass into an algorithm
+ 4 * size // Descriptor count is number of possible tensors to pass into an algorithm
  )
  };

@@ -795,7 +790,8 @@ static void ggml_vk_soft_max(
  const std::shared_ptr<kp::Tensor>& out,
  uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
  int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
- float scale
+ float scale, float max_bias, float m0, float m1,
+ uint32_t n_head_log2
  ) {
  const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
  kp::shader_data::op_softmax_comp_spv_len);
@@ -803,12 +799,14 @@ static void ggml_vk_soft_max(
  struct PushConstants {
  uint32_t inAOff, inBOff, outOff;
  int32_t ne00, ne01, ne02;
- float scale;
+ float scale, max_bias, m0, m1;
+ uint32_t n_head_log2;
  int32_t mask;
  } pushConsts {
  safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
  ne00, ne01, ne02,
- scale,
+ scale, max_bias, m0, m1,
+ n_head_log2,
  bool(inB)
  };

@@ -918,9 +916,9 @@ static void ggml_vk_mul_mat_f16(
  const std::shared_ptr<kp::Tensor>& out,
  uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
  int32_t ne00, int32_t ne01, int32_t ne02,
- uint32_t nb00, uint32_t nb01, uint32_t nb02,
+ uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
  int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
- uint32_t nb10, uint32_t nb11, uint32_t nb12,
+ uint32_t nb10, uint32_t nb11, uint32_t nb12, uint32_t nb13,
  int32_t ne0, int32_t ne1,
  uint32_t r2, uint32_t r3
  ) {
@@ -930,17 +928,17 @@ static void ggml_vk_mul_mat_f16(
  struct PushConstants {
  uint32_t inAOff, inBOff, outOff;
  int32_t ne00, ne01, ne02;
- uint32_t nb00, nb01, nb02;
+ uint32_t nb00, nb01, nb02, nb03;
  int32_t ne10, ne11, ne12;
- uint32_t nb10, nb11, nb12;
+ uint32_t nb10, nb11, nb12, nb13;
  int32_t ne0, ne1;
  uint32_t r2, r3;
  } pushConsts {
  safe_divide(inAOff, 2), safe_divide(inBOff, 4), safe_divide(outOff, 4),
  ne00, ne01, ne02,
- nb00, nb01, nb02,
+ nb00, nb01, nb02, nb03,
  ne10, ne11, ne12,
- nb10, nb11, nb12,
+ nb10, nb11, nb12, nb13,
  ne0, ne1,
  r2, r3
  };
@@ -1020,6 +1018,8 @@ static void ggml_vk_mul_mat_impl(
  int32_t ne00, int32_t ne01, int32_t ne02,
  int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
  int32_t ne0, int32_t ne1,
+ uint32_t nb01, uint32_t nb02, uint32_t nb03,
+ uint32_t nb11, uint32_t nb12, uint32_t nb13,
  uint32_t r2, uint32_t r3
  ) {
  struct PushConstants {
@@ -1027,19 +1027,23 @@ static void ggml_vk_mul_mat_impl(
  int32_t ne00, ne01, ne02;
  int32_t ne10, ne12;
  int32_t ne0, ne1;
+ uint32_t nb01, nb02, nb03;
+ uint32_t nb11, nb12, nb13;
  uint32_t r2, r3;
  } pushConsts {
  safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
  ne00, ne01, ne02,
  ne10, ne12,
  ne0, ne1,
+ nb01, nb02, nb03,
+ nb11, nb12, nb13,
  r2, r3
  };

  auto name = std::string(__func__) + "_" + suffix;
  std::shared_ptr<kp::Algorithm> s_algo = nullptr;
  if (!komputeManager()->hasAlgorithm(name)) {
- const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
+ const uint32_t local_x = (ggml_vk_current_device().subgroupSize * 2) / 8;
  s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts});
  } else {
  s_algo = komputeManager()->getAlgorithm(name);
@@ -1075,34 +1079,84 @@ static void ggml_vk_mul_mat_q8_0(Args&&... args) {
  ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
  }

+ static void ggml_vk_mul_mat_q4_k(
+ kp::Sequence& seq,
+ const std::shared_ptr<kp::Tensor>& inA,
+ const std::shared_ptr<kp::Tensor>& inB,
+ const std::shared_ptr<kp::Tensor>& out,
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+ int32_t ne00, int32_t ne01, int32_t ne02,
+ int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
+ int32_t ne0, int32_t ne1,
+ uint32_t nb01, uint32_t nb02, uint32_t nb03,
+ uint32_t nb11, uint32_t nb12, uint32_t nb13,
+ uint32_t r2, uint32_t r3
+ ) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
+ kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
+
+ struct PushConstants {
+ uint32_t inAOff, inBOff, outOff;
+ int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12;
+ uint32_t nb01, nb02, nb03, nb11, nb12, nb13;
+ uint32_t r2, r3;
+ } pushConsts {
+ inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
+ ne00, ne10, ne0, ne1, ne01, ne02, ne12,
+ nb01, nb02, nb03, nb11, nb12, nb13,
+ r2, r3
+ };
+
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+ if (!komputeManager()->hasAlgorithm(__func__)) {
+ s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts});
+ } else {
+ s_algo = komputeManager()->getAlgorithm(__func__);
+ s_algo->setTensors({inA, inB, out});
+ s_algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)});
+ s_algo->setPushConstants<PushConstants>({pushConsts});
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
+ }
+ seq.record<kp::OpAlgoDispatch>(s_algo);
+ }
+
  static void ggml_vk_mul_mat_q6_k(
  kp::Sequence& seq,
  const std::shared_ptr<kp::Tensor>& inA,
  const std::shared_ptr<kp::Tensor>& inB,
  const std::shared_ptr<kp::Tensor>& out,
  uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
- int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1,
- int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02
+ int32_t ne00, int32_t ne01, int32_t ne02,
+ int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
+ int32_t ne0, int32_t ne1,
+ uint32_t nb01, uint32_t nb02, uint32_t nb03,
+ uint32_t nb11, uint32_t nb12, uint32_t nb13,
+ uint32_t r2, uint32_t r3
  ) {
  const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q6_k_comp_spv,
  kp::shader_data::op_mul_mat_q6_k_comp_spv_len);

  struct PushConstants {
  uint32_t inAOff, inBOff, outOff;
- int32_t ne00, ne10, ne0, ne1, ne01, gqa;
+ int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12;
+ uint32_t nb01, nb02, nb03, nb11, nb12, nb13;
+ uint32_t r2, r3;
  } pushConsts {
  inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
- ne00, ne10, ne0, ne1, ne01, ne12/ne02
+ ne00, ne10, ne0, ne1, ne01, ne02, ne12,
+ nb01, nb02, nb03, nb11, nb12, nb13,
+ r2, r3
  };

  std::shared_ptr<kp::Algorithm> s_algo = nullptr;
  if (!komputeManager()->hasAlgorithm(__func__)) {
- const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
- s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}, {local_x}, {pushConsts});
+ const uint32_t local_x = 2;
+ const uint32_t local_y = ggml_vk_current_device().subgroupSize;
+ s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)}, {local_x, local_y}, {pushConsts});
  } else {
  s_algo = komputeManager()->getAlgorithm(__func__);
  s_algo->setTensors({inA, inB, out});
- s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)});
+ s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)});
  s_algo->setPushConstants<PushConstants>({pushConsts});
  s_algo->updateDescriptors(s_kompute_context->pool.get());
  }
@@ -1190,10 +1244,11 @@ static void ggml_vk_rope(
  kp::Sequence& seq,
  const std::shared_ptr<kp::Tensor>& inA,
  const std::shared_ptr<kp::Tensor>& inB,
+ const std::shared_ptr<kp::Tensor>& inC,
  const std::shared_ptr<kp::Tensor>& out,
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+ uint32_t inAOff, uint32_t inBOff, uint32_t inCOff, uint32_t outOff,
  ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
- float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
+ float freq_base, float freq_scale, bool has_freq_factors, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
  int32_t ne01, int32_t ne02, int32_t ne03,
  uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
  int32_t ne0,
@@ -1201,11 +1256,17 @@ static void ggml_vk_rope(
  ) {
  GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32);

- static const auto spirv_f16 = getSpirvShader(
- kp::shader_data::op_rope_f16_comp_spv, kp::shader_data::op_rope_f16_comp_spv_len
+ static const auto spirv_norm_f16 = getSpirvShader(
+ kp::shader_data::op_rope_norm_f16_comp_spv, kp::shader_data::op_rope_norm_f16_comp_spv_len
+ );
+ static const auto spirv_norm_f32 = getSpirvShader(
+ kp::shader_data::op_rope_norm_f32_comp_spv, kp::shader_data::op_rope_norm_f32_comp_spv_len
  );
- static const auto spirv_f32 = getSpirvShader(
- kp::shader_data::op_rope_f32_comp_spv, kp::shader_data::op_rope_f32_comp_spv_len
+ static const auto spirv_neox_f16 = getSpirvShader(
+ kp::shader_data::op_rope_neox_f16_comp_spv, kp::shader_data::op_rope_neox_f16_comp_spv_len
+ );
+ static const auto spirv_neox_f32 = getSpirvShader(
+ kp::shader_data::op_rope_neox_f32_comp_spv, kp::shader_data::op_rope_neox_f32_comp_spv_len
  );

  int type_size = src0t == GGML_TYPE_F16 ? 2 : 4;
@@ -1220,32 +1281,40 @@ static void ggml_vk_rope(
  GGML_ASSERT(nb0 % type_size == 0);

  struct PushConstants {
- uint32_t inAOff, inBOff, outOff;
+ uint32_t inAOff, inBOff, inCOff, outOff;
  int32_t n_dims, mode, n_ctx_orig;
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+ float freq_base, freq_scale;
+ bool has_freq_factors;
+ float ext_factor, attn_factor, beta_fast, beta_slow;
  uint32_t nb00, nb01, nb02, nb03;
  int32_t ne0;
  uint32_t nb0, nb1, nb2, nb3;
  } pushConsts {
- safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
+ safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(inCOff, type_size), safe_divide(outOff, type_size),
  n_dims, mode, n_ctx_orig,
- freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
+ freq_base, freq_scale,
+ has_freq_factors,
+ ext_factor, attn_factor, beta_fast, beta_slow,
  nb00, nb01, nb02, nb03,
  ne0,
  nb0, nb1, nb2, nb3
  };

- auto name = std::string(__func__) + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
+ auto & inC_ = inC ? inC : inA;
+ const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+ const bool is_f16 = src0t == GGML_TYPE_F16;
+
+ auto name = std::string(__func__) + (is_neox ? "_neox" : "_norm") + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
  std::shared_ptr<kp::Algorithm> s_algo = nullptr;
  if (!komputeManager()->hasAlgorithm(name)) {
+ auto & spirv = is_neox ? is_f16 ? spirv_neox_f16 : spirv_neox_f32 : is_f16 ? spirv_norm_f16 : spirv_norm_f32;
  s_algo = komputeManager()->algorithm<float, PushConstants>(
- name, s_kompute_context->pool.get(), {inA, inB, out},
- src0t == GGML_TYPE_F16 ? spirv_f16 : spirv_f32,
+ name, s_kompute_context->pool.get(), {inA, inB, inC_, out}, spirv,
  {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}
  );
  } else {
  s_algo = komputeManager()->getAlgorithm(name);
- s_algo->setTensors({inA, inB, out});
+ s_algo->setTensors({inA, inB, inC_, out});
  s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
  s_algo->setPushConstants<PushConstants>({pushConsts});
  s_algo->updateDescriptors(s_kompute_context->pool.get());
@@ -1323,22 +1392,16 @@ static void ggml_vk_cpy_f16_f32(Args&&... args) {
  ggml_vk_cpy(spirv, 2, 4, std::forward<Args>(args)...);
  }

- static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
- switch (op->type) {
- case GGML_TYPE_F16:
- case GGML_TYPE_F32:
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- break;
- default:
- return false;
- }
-
+ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+ int64_t n = ggml_nelements(op);
  switch (op->op) {
  case GGML_OP_UNARY:
+ if (n % 4 != 0) return false;
  switch (ggml_get_unary_op(op)) {
- case GGML_UNARY_OP_RELU:
  case GGML_UNARY_OP_GELU:
+ if (n % 8 != 0) return false;
+ // fall through
+ case GGML_UNARY_OP_RELU:
  case GGML_UNARY_OP_SILU:
  return ggml_is_contiguous(op->src[0]);
  default:
@@ -1356,8 +1419,18 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
  case GGML_OP_SOFT_MAX:
  case GGML_OP_RMS_NORM:
  case GGML_OP_NORM:
- case GGML_OP_ROPE:
  return true;
+ case GGML_OP_ROPE:
+ {
+ const int mode = ((const int32_t *) op->op_params)[2];
+ if (mode & GGML_ROPE_TYPE_MROPE) {
+ return false;
+ }
+ if (mode & GGML_ROPE_TYPE_VISION) {
+ return false;
+ }
+ return true;
+ }
  case GGML_OP_DUP:
  case GGML_OP_CPY:
  case GGML_OP_CONT:
@@ -1396,12 +1469,13 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {

  switch (op->src[0]->type) {
  case GGML_TYPE_F32:
- case GGML_TYPE_Q6_K:
  return op->ne[3] == 1;
+ case GGML_TYPE_Q6_K:
  case GGML_TYPE_F16:
  case GGML_TYPE_Q8_0:
  case GGML_TYPE_Q4_0:
  case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q4_K:
  return true;
  default:
  ;
@@ -1410,6 +1484,8 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
  ;
  }
  return false;
+
+ GGML_UNUSED(dev);
  }

  static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
@@ -1458,11 +1534,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml

  any_commands_recorded = true;

- if (!ggml_vk_supports_op(dst)) {
- fprintf(stderr, "%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
- GGML_ABORT("unsupported op");
- }
-
  const int32_t ne00 = src0 ? src0->ne[0] : 0;
  const int32_t ne01 = src0 ? src0->ne[1] : 0;
  const int32_t ne02 = src0 ? src0->ne[2] : 0;
@@ -1500,9 +1571,11 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
  const static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
  uint32_t off_src0 = 0;
  uint32_t off_src1 = 0;
+ uint32_t off_src2 = 0;
  uint32_t off_dst = 0;
  const std::shared_ptr<kp::Tensor>& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor;
  const std::shared_ptr<kp::Tensor>& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor;
+ const std::shared_ptr<kp::Tensor>& id_src2 = src2 ? ggml_vk_get_tensor(src2, &off_src2) : nullTensor;
  const std::shared_ptr<kp::Tensor>& id_dst = dst ? ggml_vk_get_tensor(dst, &off_dst) : nullTensor;

  switch (dst->op) {
@@ -1578,11 +1651,16 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
  #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
  GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);

- #pragma message("TODO: add ALiBi support")
- #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
- GGML_ASSERT(max_bias == 0.0f);
+ const int64_t nrows_x = ggml_nrows(src0);
+ const int64_t nrows_y = src0->ne[1];
+
+ const uint32_t n_head = nrows_x/nrows_y;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

- ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale, max_bias, m0, m1, n_head_log2);
  } break;
  case GGML_OP_DIAG_MASK_INF:
  {
@@ -1634,32 +1712,44 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
  case GGML_TYPE_F16:
  ggml_vk_mul_mat_f16(
  seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
- ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, ne13, nb10, nb11, nb12,
+ ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+ ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
  ne0, ne1, r2, r3
  );
  break;
  case GGML_TYPE_Q8_0:
  ggml_vk_mul_mat_q8_0(
  seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
- ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
+ ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+ nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
  );
  break;
  case GGML_TYPE_Q4_0:
  ggml_vk_mul_mat_q4_0(
  seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
- ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
+ ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+ nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
  );
  break;
  case GGML_TYPE_Q4_1:
  ggml_vk_mul_mat_q4_1(
  seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
- ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
+ ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+ nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
+ );
+ break;
+ case GGML_TYPE_Q4_K:
+ ggml_vk_mul_mat_q4_k(
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+ ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+ nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
  );
  break;
  case GGML_TYPE_Q6_K:
  ggml_vk_mul_mat_q6_k(
  seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
- ne00, ne10, ne0, ne1, ne01, ne11, ne12, ne02
+ ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+ nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
  );
  break;
  default: {
@@ -1688,13 +1778,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
  } break;
  case GGML_OP_ROPE:
  {
- #pragma message("TODO: implement phi3 frequency factors support")
- #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
- GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
-
- #pragma message("TODO: update rope NORM mode to match NEOX mode")
- #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
-
  GGML_ASSERT(ne10 == ne02);
  GGML_ASSERT(src0t == dstt);
  // const int n_past = ((int32_t *) dst->op_params)[0];
@@ -1703,6 +1786,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
  // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
  const int n_ctx_orig = ((int32_t *) dst->op_params)[4];

+ const bool has_freq_factors = dst->src[2] != nullptr;
+
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
  memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
@@ -1711,8 +1796,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
  memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
  ggml_vk_rope(
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
- freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
+ seq, id_src0, id_src1, id_src2, id_dst, off_src0, off_src1, off_src2, off_dst, src0t, n_dims, mode, n_ctx_orig,
+ freq_base, freq_scale, has_freq_factors, ext_factor, attn_factor, beta_fast, beta_slow,
  ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
  );
  } break;
@@ -1820,11 +1905,6 @@ static void ggml_backend_kompute_device_unref(ggml_backend_buffer_type_t buft) {
  }
  }

- static const char * ggml_backend_kompute_buffer_get_name(ggml_backend_buffer_t buffer) {
- auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buffer->buft->context);
- return ctx->name.c_str();
- }
-
  static void ggml_backend_kompute_buffer_free_buffer(ggml_backend_buffer_t buffer) {
  auto * memory = (ggml_vk_memory *)buffer->context;
  if (ggml_vk_has_device()) {
@@ -1868,7 +1948,6 @@ static void ggml_backend_kompute_buffer_clear(ggml_backend_buffer_t buffer, uint
  }

  static ggml_backend_buffer_i ggml_backend_kompute_buffer_i = {
- /* .get_name = */ ggml_backend_kompute_buffer_get_name,
  /* .free_buffer = */ ggml_backend_kompute_buffer_free_buffer,
  /* .get_base = */ ggml_backend_kompute_buffer_get_base,
  /* .init_tensor = */ NULL,
@@ -1913,25 +1992,31 @@ static ggml_backend_buffer_type_i ggml_backend_kompute_buffer_type_interface = {
  };

  ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) {
- static std::vector<ggml_backend_buffer_type> bufts = []() {
- std::vector<ggml_backend_buffer_type> vec;
- auto devices = ggml_vk_available_devices_internal(0);
- vec.reserve(devices.size());
-
- for (const auto & dev : devices) {
- vec.push_back({
- /* .iface = */ ggml_backend_kompute_buffer_type_interface,
- /* .device = */ nullptr,
- /* .context = */ new ggml_backend_kompute_buffer_type_context(dev.index, dev.bufferAlignment, dev.maxAlloc)
- });
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+
+ auto devices = ggml_vk_available_devices();
+ int32_t device_count = (int32_t) devices.size();
+ GGML_ASSERT(device < device_count);
+ GGML_ASSERT(devices.size() <= GGML_KOMPUTE_MAX_DEVICES);
+
+ static ggml_backend_buffer_type
+ ggml_backend_kompute_buffer_types[GGML_KOMPUTE_MAX_DEVICES];
+
+ static bool ggml_backend_kompute_buffer_type_initialized = false;
+
+ if (!ggml_backend_kompute_buffer_type_initialized) {
+ for (int32_t i = 0; i < device_count; i++) {
+ ggml_backend_kompute_buffer_types[i] = {
+ /* .iface = */ ggml_backend_kompute_buffer_type_interface,
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), i),
+ /* .context = */ new ggml_backend_kompute_buffer_type_context{ i, devices[i].bufferAlignment, devices[i].maxAlloc },
+ };
  }
- return vec;
- }();
+ ggml_backend_kompute_buffer_type_initialized = true;
+ }

- auto it = std::find_if(bufts.begin(), bufts.end(), [device](const ggml_backend_buffer_type & t) {
- return device == static_cast<ggml_backend_kompute_buffer_type_context *>(t.context)->device;
- });
- return it < bufts.end() ? &*it : nullptr;
+ return &ggml_backend_kompute_buffer_types[device];
  }

  // backend
@@ -1953,31 +2038,15 @@ static void ggml_backend_kompute_free(ggml_backend_t backend) {
  delete backend;
  }

- static ggml_backend_buffer_type_t ggml_backend_kompute_get_default_buffer_type(ggml_backend_t backend) {
- auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
- return ggml_backend_kompute_buffer_type(ctx->device);
- }
-
  static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
  auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
  ggml_vk_graph_compute(ctx, cgraph);
  return GGML_STATUS_SUCCESS;
  }

- static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
- GGML_UNUSED(backend);
- return ggml_vk_supports_op(op);
- }
-
- static bool ggml_backend_kompute_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
- GGML_UNUSED(backend);
- return buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name;
- }
-
  static struct ggml_backend_i kompute_backend_i = {
  /* .get_name = */ ggml_backend_kompute_name,
  /* .free = */ ggml_backend_kompute_free,
- /* .get_default_buffer_type = */ ggml_backend_kompute_get_default_buffer_type,
  /* .set_tensor_async = */ NULL,
  /* .get_tensor_async = */ NULL,
  /* .cpy_tensor_async = */ NULL,
@@ -1987,9 +2056,6 @@ static struct ggml_backend_i kompute_backend_i = {
  /* .graph_plan_update = */ NULL,
  /* .graph_plan_compute = */ NULL,
  /* .graph_compute = */ ggml_backend_kompute_graph_compute,
- /* .supports_op = */ ggml_backend_kompute_supports_op,
- /* .supports_buft = */ ggml_backend_kompute_supports_buft,
- /* .offload_op = */ NULL,
  /* .event_record = */ NULL,
  /* .event_wait = */ NULL,
  };
@@ -2006,7 +2072,7 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
  ggml_backend_t kompute_backend = new ggml_backend {
  /* .guid = */ ggml_backend_kompute_guid(),
  /* .interface = */ kompute_backend_i,
- /* .device = */ nullptr,
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), device),
  /* .context = */ s_kompute_context,
  };

@@ -2016,3 +2082,170 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
  bool ggml_backend_is_kompute(ggml_backend_t backend) {
  return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
  }
+
+ static size_t ggml_backend_kompute_get_device_count() {
+ auto devices = ggml_vk_available_devices();
+ return devices.size();
+ }
+
+ static void ggml_backend_kompute_get_device_description(int device, char * description, size_t description_size) {
+ auto devices = ggml_vk_available_devices();
+ GGML_ASSERT((size_t) device < devices.size());
+ snprintf(description, description_size, "%s", devices[device].name);
+ }
+
+ static void ggml_backend_kompute_get_device_memory(int device, size_t * free, size_t * total) {
+ auto devices = ggml_vk_available_devices();
+ GGML_ASSERT((size_t) device < devices.size());
+ *total = devices[device].heapSize;
+ *free = devices[device].heapSize;
+ }
+
+ //////////////////////////
+
+ struct ggml_backend_kompute_device_context {
+ int device;
+ std::string name;
+ std::string description;
+ };
+
+ static const char * ggml_backend_kompute_device_get_name(ggml_backend_dev_t dev) {
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+ return ctx->name.c_str();
+ }
+
+ static const char * ggml_backend_kompute_device_get_description(ggml_backend_dev_t dev) {
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+ return ctx->description.c_str();
+ }
+
+ static void ggml_backend_kompute_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+ ggml_backend_kompute_get_device_memory(ctx->device, free, total);
+ }
+
+ static ggml_backend_buffer_type_t ggml_backend_kompute_device_get_buffer_type(ggml_backend_dev_t dev) {
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+ return ggml_backend_kompute_buffer_type(ctx->device);
+ }
+
+ static bool ggml_backend_kompute_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+ if (buft->iface.get_name != ggml_backend_kompute_buffer_type_get_name) {
+ return false;
+ }
+
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+ ggml_backend_kompute_buffer_type_context * buft_ctx = (ggml_backend_kompute_buffer_type_context *)buft->context;
+
+ return buft_ctx->device == ctx->device;
+ }
+
+ static enum ggml_backend_dev_type ggml_backend_kompute_device_get_type(ggml_backend_dev_t dev) {
+ GGML_UNUSED(dev);
+ return GGML_BACKEND_DEVICE_TYPE_GPU;
+ }
+
+ static void ggml_backend_kompute_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+ props->name = ggml_backend_kompute_device_get_name(dev);
+ props->description = ggml_backend_kompute_device_get_description(dev);
+ props->type = ggml_backend_kompute_device_get_type(dev);
+ ggml_backend_kompute_device_get_memory(dev, &props->memory_free, &props->memory_total);
+ props->caps = {
+ /* async = */ false,
+ /* host_buffer = */ false,
+ /* .buffer_from_host_ptr = */ false,
+ /* events = */ false,
+ };
+ }
+
+ static ggml_backend_t ggml_backend_kompute_device_init(ggml_backend_dev_t dev, const char * params) {
+ GGML_UNUSED(params);
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+ return ggml_backend_kompute_init(ctx->device);
+ }
+
+ static bool ggml_backend_kompute_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+ const int min_batch_size = 32;
+
+ return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+ (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
+
+ GGML_UNUSED(dev);
+ }
+
+ static const struct ggml_backend_device_i ggml_backend_kompute_device_i = {
+ /* .get_name = */ ggml_backend_kompute_device_get_name,
+ /* .get_description = */ ggml_backend_kompute_device_get_description,
+ /* .get_memory = */ ggml_backend_kompute_device_get_memory,
+ /* .get_type = */ ggml_backend_kompute_device_get_type,
+ /* .get_props = */ ggml_backend_kompute_device_get_props,
+ /* .init_backend = */ ggml_backend_kompute_device_init,
+ /* .get_buffer_type = */ ggml_backend_kompute_device_get_buffer_type,
+ /* .get_host_buffer_type = */ NULL,
+ /* .buffer_from_host_ptr = */ NULL,
+ /* .supports_op = */ ggml_backend_kompute_device_supports_op,
+ /* .supports_buft = */ ggml_backend_kompute_device_supports_buft,
+ /* .offload_op = */ ggml_backend_kompute_device_offload_op,
+ /* .event_new = */ NULL,
+ /* .event_free = */ NULL,
+ /* .event_synchronize = */ NULL,
+ };
+
+ static const char * ggml_backend_kompute_reg_get_name(ggml_backend_reg_t reg) {
+ GGML_UNUSED(reg);
+ return "Kompute";
+ }
+
+ static size_t ggml_backend_kompute_reg_get_device_count(ggml_backend_reg_t reg) {
+ GGML_UNUSED(reg);
+ return ggml_backend_kompute_get_device_count();
+ }
+
+ static ggml_backend_dev_t ggml_backend_kompute_reg_get_device(ggml_backend_reg_t reg, size_t device) {
+ static std::vector<ggml_backend_dev_t> devices;
+
+ static bool initialized = false;
+
+ {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+ if (!initialized) {
+ for (size_t i = 0; i < ggml_backend_kompute_get_device_count(); i++) {
+ ggml_backend_kompute_device_context * ctx = new ggml_backend_kompute_device_context;
+ char desc[256];
+ ggml_backend_kompute_get_device_description(i, desc, sizeof(desc));
+ ctx->device = i;
+ ctx->name = "Kompute" + std::to_string(i);
+ ctx->description = desc;
+ devices.push_back(new ggml_backend_device {
+ /* .iface = */ ggml_backend_kompute_device_i,
+ /* .reg = */ reg,
+ /* .context = */ ctx,
+ });
+ }
+ initialized = true;
+ }
+ }
+
+ GGML_ASSERT(device < devices.size());
+ return devices[device];
+ }
+
+ static const struct ggml_backend_reg_i ggml_backend_kompute_reg_i = {
+ /* .get_name = */ ggml_backend_kompute_reg_get_name,
+ /* .get_device_count = */ ggml_backend_kompute_reg_get_device_count,
+ /* .get_device = */ ggml_backend_kompute_reg_get_device,
+ /* .get_proc_address = */ NULL,
+ };
+
+ ggml_backend_reg_t ggml_backend_kompute_reg() {
+ static ggml_backend_reg reg = {
+ /* .api_version = */ GGML_BACKEND_API_VERSION,
+ /* .iface = */ ggml_backend_kompute_reg_i,
+ /* .context = */ nullptr,
+ };
+
+ return &reg;
+ }
+
+ GGML_BACKEND_DL_IMPL(ggml_backend_kompute_reg)