@fugood/llama.node 0.3.3 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (225)
  1. package/CMakeLists.txt +5 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +18 -1
  17. package/package.json +1 -1
  18. package/src/EmbeddingWorker.cpp +15 -5
  19. package/src/EmbeddingWorker.h +2 -1
  20. package/src/LlamaCompletionWorker.cpp +1 -1
  21. package/src/LlamaContext.cpp +81 -18
  22. package/src/LlamaContext.h +2 -0
  23. package/src/llama.cpp/.github/workflows/build.yml +197 -159
  24. package/src/llama.cpp/.github/workflows/docker.yml +5 -8
  25. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  26. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  27. package/src/llama.cpp/CMakeLists.txt +11 -6
  28. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  29. package/src/llama.cpp/cmake/common.cmake +33 -0
  30. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  31. package/src/llama.cpp/common/CMakeLists.txt +6 -2
  32. package/src/llama.cpp/common/arg.cpp +426 -245
  33. package/src/llama.cpp/common/common.cpp +143 -80
  34. package/src/llama.cpp/common/common.h +81 -24
  35. package/src/llama.cpp/common/sampling.cpp +53 -19
  36. package/src/llama.cpp/common/sampling.h +22 -1
  37. package/src/llama.cpp/common/speculative.cpp +274 -0
  38. package/src/llama.cpp/common/speculative.h +28 -0
  39. package/src/llama.cpp/docs/build.md +101 -148
  40. package/src/llama.cpp/examples/CMakeLists.txt +32 -13
  41. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +5 -4
  43. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  46. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  47. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  48. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  49. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  50. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  51. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  52. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  54. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  55. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  57. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  58. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  59. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  60. package/src/llama.cpp/examples/imatrix/imatrix.cpp +11 -2
  61. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/infill/infill.cpp +1 -1
  63. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  64. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +405 -316
  65. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  66. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  67. package/src/llama.cpp/examples/llava/clip.cpp +262 -66
  68. package/src/llama.cpp/examples/llava/clip.h +8 -2
  69. package/src/llama.cpp/examples/llava/llava-cli.cpp +1 -1
  70. package/src/llama.cpp/examples/llava/llava.cpp +46 -19
  71. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +1 -1
  72. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  73. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
  75. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  76. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -1
  77. package/src/llama.cpp/examples/lookup/lookup.cpp +2 -2
  78. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/main/main.cpp +9 -5
  80. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  81. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
  83. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  84. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  87. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  88. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  89. package/src/llama.cpp/examples/retrieval/retrieval.cpp +4 -4
  90. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  91. package/src/llama.cpp/examples/run/run.cpp +911 -0
  92. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -4
  94. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -7
  95. package/src/llama.cpp/examples/server/server.cpp +1758 -886
  96. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  97. package/src/llama.cpp/examples/server/utils.hpp +94 -304
  98. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  99. package/src/llama.cpp/examples/simple/simple.cpp +4 -0
  100. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +1 -1
  101. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +3 -0
  102. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/speculative/speculative.cpp +16 -15
  104. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  105. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  106. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/tokenize/tokenize.cpp +1 -1
  108. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  109. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  110. package/src/llama.cpp/ggml/CMakeLists.txt +46 -34
  111. package/src/llama.cpp/ggml/include/ggml-backend.h +16 -0
  112. package/src/llama.cpp/ggml/include/ggml-cpu.h +7 -49
  113. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  114. package/src/llama.cpp/ggml/include/ggml.h +106 -24
  115. package/src/llama.cpp/ggml/src/CMakeLists.txt +73 -24
  116. package/src/llama.cpp/ggml/src/ggml-alloc.c +0 -1
  117. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +51 -11
  118. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +379 -22
  119. package/src/llama.cpp/ggml/src/ggml-backend.cpp +4 -4
  120. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -7
  121. package/src/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +5 -2
  122. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +33 -3
  123. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  124. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  125. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +95 -35
  126. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  127. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  128. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  129. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  130. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  131. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  132. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  133. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  134. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  135. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +288 -213
  136. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  137. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  138. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/common.h +19 -22
  139. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.cpp +93 -92
  140. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.h +2 -9
  141. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  142. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +892 -190
  143. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +2 -24
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +15 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +38 -25
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +552 -399
  151. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +101 -136
  152. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +2 -2
  153. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +7 -10
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  155. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -6
  156. package/src/llama.cpp/ggml/src/ggml-impl.h +32 -11
  157. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +13 -9
  158. package/src/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +131 -64
  159. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +3 -6
  160. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +39 -0
  161. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +14 -7
  162. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  163. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  164. package/src/llama.cpp/ggml/src/ggml-opt.cpp +67 -80
  165. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -9
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +3 -5
  167. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +5 -2
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +13 -10
  169. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +2 -11
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +2 -2
  172. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  173. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  174. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +32 -13
  175. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +80 -61
  176. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  177. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +159 -114
  178. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  179. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  180. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +6 -20
  181. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +4 -3
  182. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +8 -8
  183. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  184. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  185. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +4 -1
  187. package/src/llama.cpp/ggml/src/ggml-threading.h +4 -2
  188. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +21 -7
  189. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1718 -399
  190. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +3 -1
  191. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +105 -31
  192. package/src/llama.cpp/ggml/src/ggml.c +367 -207
  193. package/src/llama.cpp/include/llama-cpp.h +25 -0
  194. package/src/llama.cpp/include/llama.h +26 -19
  195. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  196. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  197. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  198. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  199. package/src/llama.cpp/src/CMakeLists.txt +2 -7
  200. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  201. package/src/llama.cpp/src/llama-grammar.h +2 -5
  202. package/src/llama.cpp/src/llama-sampling.cpp +35 -90
  203. package/src/llama.cpp/src/llama-vocab.cpp +6 -1
  204. package/src/llama.cpp/src/llama.cpp +1748 -640
  205. package/src/llama.cpp/src/unicode.cpp +62 -51
  206. package/src/llama.cpp/src/unicode.h +9 -10
  207. package/src/llama.cpp/tests/CMakeLists.txt +48 -37
  208. package/src/llama.cpp/tests/test-arg-parser.cpp +2 -2
  209. package/src/llama.cpp/tests/test-backend-ops.cpp +140 -21
  210. package/src/llama.cpp/tests/test-chat-template.cpp +50 -4
  211. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  212. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  213. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  214. package/src/llama.cpp/tests/test-quantize-fns.cpp +3 -3
  215. package/src/llama.cpp/tests/test-rope.cpp +61 -20
  216. package/src/llama.cpp/tests/test-sampling.cpp +2 -2
  217. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  218. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  219. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  220. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  221. package/src/llama.cpp/ggml/include/ggml-amx.h +0 -25
  222. package/src/llama.cpp/ggml/src/ggml-aarch64.c +0 -129
  223. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -19
  224. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  225. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
package/src/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp
@@ -28,8 +28,10 @@
 #include "shaderop_getrows_q4_0.h"
 #include "shaderop_getrows_q4_1.h"
 #include "shaderop_getrows_q6_k.h"
-#include "shaderop_rope_f16.h"
-#include "shaderop_rope_f32.h"
+#include "shaderop_rope_norm_f16.h"
+#include "shaderop_rope_norm_f32.h"
+#include "shaderop_rope_neox_f16.h"
+#include "shaderop_rope_neox_f32.h"
 #include "shaderop_cpy_f16_f16.h"
 #include "shaderop_cpy_f16_f32.h"
 #include "shaderop_cpy_f32_f16.h"
@@ -345,7 +347,7 @@ void ggml_vk_allocate_descriptor_pool(struct ggml_kompute_context * ctx, size_t
     std::vector<vk::DescriptorPoolSize> descriptorPoolSizes = {
         vk::DescriptorPoolSize(
             vk::DescriptorType::eStorageBuffer,
-            3 * size // Descriptor count is number of possible tensors to pass into an algorithm
+            4 * size // Descriptor count is number of possible tensors to pass into an algorithm
         )
     };
 
@@ -788,7 +790,8 @@ static void ggml_vk_soft_max(
         const std::shared_ptr<kp::Tensor>& out,
         uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
         int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
-        float scale
+        float scale, float max_bias, float m0, float m1,
+        uint32_t n_head_log2
 ) {
     const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
                                              kp::shader_data::op_softmax_comp_spv_len);
@@ -796,12 +799,14 @@ static void ggml_vk_soft_max(
     struct PushConstants {
         uint32_t inAOff, inBOff, outOff;
         int32_t ne00, ne01, ne02;
-        float scale;
+        float scale, max_bias, m0, m1;
+        uint32_t n_head_log2;
         int32_t mask;
     } pushConsts {
         safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
         ne00, ne01, ne02,
-        scale,
+        scale, max_bias, m0, m1,
+        n_head_log2,
         bool(inB)
     };
 
@@ -911,9 +916,9 @@ static void ggml_vk_mul_mat_f16(
         const std::shared_ptr<kp::Tensor>& out,
         uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
         int32_t ne00, int32_t ne01, int32_t ne02,
-        uint32_t nb00, uint32_t nb01, uint32_t nb02,
+        uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
         int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
-        uint32_t nb10, uint32_t nb11, uint32_t nb12,
+        uint32_t nb10, uint32_t nb11, uint32_t nb12, uint32_t nb13,
         int32_t ne0, int32_t ne1,
         uint32_t r2, uint32_t r3
 ) {
@@ -923,17 +928,17 @@ static void ggml_vk_mul_mat_f16(
     struct PushConstants {
         uint32_t inAOff, inBOff, outOff;
         int32_t ne00, ne01, ne02;
-        uint32_t nb00, nb01, nb02;
+        uint32_t nb00, nb01, nb02, nb03;
         int32_t ne10, ne11, ne12;
-        uint32_t nb10, nb11, nb12;
+        uint32_t nb10, nb11, nb12, nb13;
         int32_t ne0, ne1;
         uint32_t r2, r3;
     } pushConsts {
         safe_divide(inAOff, 2), safe_divide(inBOff, 4), safe_divide(outOff, 4),
         ne00, ne01, ne02,
-        nb00, nb01, nb02,
+        nb00, nb01, nb02, nb03,
         ne10, ne11, ne12,
-        nb10, nb11, nb12,
+        nb10, nb11, nb12, nb13,
         ne0, ne1,
         r2, r3
     };
@@ -1013,6 +1018,8 @@ static void ggml_vk_mul_mat_impl(
         int32_t ne00, int32_t ne01, int32_t ne02,
         int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
         int32_t ne0, int32_t ne1,
+        uint32_t nb01, uint32_t nb02, uint32_t nb03,
+        uint32_t nb11, uint32_t nb12, uint32_t nb13,
         uint32_t r2, uint32_t r3
 ) {
     struct PushConstants {
@@ -1020,19 +1027,23 @@ static void ggml_vk_mul_mat_impl(
         int32_t ne00, ne01, ne02;
         int32_t ne10, ne12;
         int32_t ne0, ne1;
+        uint32_t nb01, nb02, nb03;
+        uint32_t nb11, nb12, nb13;
         uint32_t r2, r3;
     } pushConsts {
         safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
         ne00, ne01, ne02,
         ne10, ne12,
         ne0, ne1,
+        nb01, nb02, nb03,
+        nb11, nb12, nb13,
         r2, r3
     };
 
     auto name = std::string(__func__) + "_" + suffix;
     std::shared_ptr<kp::Algorithm> s_algo = nullptr;
     if (!komputeManager()->hasAlgorithm(name)) {
-        const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
+        const uint32_t local_x = (ggml_vk_current_device().subgroupSize * 2) / 8;
         s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts});
     } else {
         s_algo = komputeManager()->getAlgorithm(name);
@@ -1074,19 +1085,26 @@ static void ggml_vk_mul_mat_q4_k(
         const std::shared_ptr<kp::Tensor>& inB,
         const std::shared_ptr<kp::Tensor>& out,
         uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-        int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
-        int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
-        int32_t ne1, int32_t r2, int32_t r3
+        int32_t ne00, int32_t ne01, int32_t ne02,
+        int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
+        int32_t ne0, int32_t ne1,
+        uint32_t nb01, uint32_t nb02, uint32_t nb03,
+        uint32_t nb11, uint32_t nb12, uint32_t nb13,
+        uint32_t r2, uint32_t r3
 ) {
     const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
                                              kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
 
     struct PushConstants {
         uint32_t inAOff, inBOff, outOff;
-        int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
+        int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12;
+        uint32_t nb01, nb02, nb03, nb11, nb12, nb13;
+        uint32_t r2, r3;
     } pushConsts {
-        0, 0, 0,
-        ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
+        inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
+        ne00, ne10, ne0, ne1, ne01, ne02, ne12,
+        nb01, nb02, nb03, nb11, nb12, nb13,
+        r2, r3
     };
 
     std::shared_ptr<kp::Algorithm> s_algo = nullptr;
@@ -1108,28 +1126,37 @@ static void ggml_vk_mul_mat_q6_k(
         const std::shared_ptr<kp::Tensor>& inB,
         const std::shared_ptr<kp::Tensor>& out,
         uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-        int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1,
-        int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02
+        int32_t ne00, int32_t ne01, int32_t ne02,
+        int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
+        int32_t ne0, int32_t ne1,
+        uint32_t nb01, uint32_t nb02, uint32_t nb03,
+        uint32_t nb11, uint32_t nb12, uint32_t nb13,
+        uint32_t r2, uint32_t r3
 ) {
     const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q6_k_comp_spv,
                                              kp::shader_data::op_mul_mat_q6_k_comp_spv_len);
 
     struct PushConstants {
         uint32_t inAOff, inBOff, outOff;
-        int32_t ne00, ne10, ne0, ne1, ne01, gqa;
+        int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12;
+        uint32_t nb01, nb02, nb03, nb11, nb12, nb13;
+        uint32_t r2, r3;
     } pushConsts {
         inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
-        ne00, ne10, ne0, ne1, ne01, ne12/ne02
+        ne00, ne10, ne0, ne1, ne01, ne02, ne12,
+        nb01, nb02, nb03, nb11, nb12, nb13,
+        r2, r3
     };
 
     std::shared_ptr<kp::Algorithm> s_algo = nullptr;
     if (!komputeManager()->hasAlgorithm(__func__)) {
-        const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
-        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}, {local_x}, {pushConsts});
+        const uint32_t local_x = 2;
+        const uint32_t local_y = ggml_vk_current_device().subgroupSize;
+        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)}, {local_x, local_y}, {pushConsts});
     } else {
         s_algo = komputeManager()->getAlgorithm(__func__);
         s_algo->setTensors({inA, inB, out});
-        s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)});
+        s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)});
         s_algo->setPushConstants<PushConstants>({pushConsts});
         s_algo->updateDescriptors(s_kompute_context->pool.get());
     }
@@ -1217,10 +1244,11 @@ static void ggml_vk_rope(
         kp::Sequence& seq,
         const std::shared_ptr<kp::Tensor>& inA,
         const std::shared_ptr<kp::Tensor>& inB,
+        const std::shared_ptr<kp::Tensor>& inC,
         const std::shared_ptr<kp::Tensor>& out,
-        uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+        uint32_t inAOff, uint32_t inBOff, uint32_t inCOff, uint32_t outOff,
         ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
-        float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
+        float freq_base, float freq_scale, bool has_freq_factors, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
         int32_t ne01, int32_t ne02, int32_t ne03,
         uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
         int32_t ne0,
@@ -1228,11 +1256,17 @@ static void ggml_vk_rope(
 ) {
     GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32);
 
-    static const auto spirv_f16 = getSpirvShader(
-        kp::shader_data::op_rope_f16_comp_spv, kp::shader_data::op_rope_f16_comp_spv_len
+    static const auto spirv_norm_f16 = getSpirvShader(
+        kp::shader_data::op_rope_norm_f16_comp_spv, kp::shader_data::op_rope_norm_f16_comp_spv_len
+    );
+    static const auto spirv_norm_f32 = getSpirvShader(
+        kp::shader_data::op_rope_norm_f32_comp_spv, kp::shader_data::op_rope_norm_f32_comp_spv_len
     );
-    static const auto spirv_f32 = getSpirvShader(
-        kp::shader_data::op_rope_f32_comp_spv, kp::shader_data::op_rope_f32_comp_spv_len
+    static const auto spirv_neox_f16 = getSpirvShader(
+        kp::shader_data::op_rope_neox_f16_comp_spv, kp::shader_data::op_rope_neox_f16_comp_spv_len
+    );
+    static const auto spirv_neox_f32 = getSpirvShader(
+        kp::shader_data::op_rope_neox_f32_comp_spv, kp::shader_data::op_rope_neox_f32_comp_spv_len
     );
 
     int type_size = src0t == GGML_TYPE_F16 ? 2 : 4;
@@ -1247,32 +1281,40 @@ static void ggml_vk_rope(
     GGML_ASSERT(nb0 % type_size == 0);
 
     struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
+        uint32_t inAOff, inBOff, inCOff, outOff;
         int32_t n_dims, mode, n_ctx_orig;
-        float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+        float freq_base, freq_scale;
+        bool has_freq_factors;
+        float ext_factor, attn_factor, beta_fast, beta_slow;
         uint32_t nb00, nb01, nb02, nb03;
         int32_t ne0;
         uint32_t nb0, nb1, nb2, nb3;
     } pushConsts {
-        safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
+        safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(inCOff, type_size), safe_divide(outOff, type_size),
         n_dims, mode, n_ctx_orig,
-        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
+        freq_base, freq_scale,
+        has_freq_factors,
+        ext_factor, attn_factor, beta_fast, beta_slow,
         nb00, nb01, nb02, nb03,
         ne0,
         nb0, nb1, nb2, nb3
     };
 
-    auto name = std::string(__func__) + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
+    auto & inC_ = inC ? inC : inA;
+    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_f16 = src0t == GGML_TYPE_F16;
+
+    auto name = std::string(__func__) + (is_neox ? "_neox" : "_norm") + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
     std::shared_ptr<kp::Algorithm> s_algo = nullptr;
     if (!komputeManager()->hasAlgorithm(name)) {
+        auto & spirv = is_neox ? is_f16 ? spirv_neox_f16 : spirv_neox_f32 : is_f16 ? spirv_norm_f16 : spirv_norm_f32;
         s_algo = komputeManager()->algorithm<float, PushConstants>(
-            name, s_kompute_context->pool.get(), {inA, inB, out},
-            src0t == GGML_TYPE_F16 ? spirv_f16 : spirv_f32,
+            name, s_kompute_context->pool.get(), {inA, inB, inC_, out}, spirv,
             {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}
         );
     } else {
         s_algo = komputeManager()->getAlgorithm(name);
-        s_algo->setTensors({inA, inB, out});
+        s_algo->setTensors({inA, inB, inC_, out});
         s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
         s_algo->setPushConstants<PushConstants>({pushConsts});
         s_algo->updateDescriptors(s_kompute_context->pool.get());
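Editor's note: the split into _norm and _neox shader variants above mirrors the two RoPE pairing schemes used elsewhere in ggml. A minimal CPU-side sketch of the difference (plain illustrative C++, not the SPIR-V shaders themselves; frequency factors, freq_scale, and the YaRN ext_factor/attn_factor corrections are omitted):

#include <cmath>

// NORM mode rotates adjacent element pairs (x[i], x[i+1]);
// theta for pair i follows the usual schedule pos * freq_base^(-i/n_dims).
void rope_norm(float * x, int n_dims, int pos, float freq_base) {
    for (int i = 0; i < n_dims; i += 2) {
        const float theta = pos * std::pow(freq_base, -(float) i / n_dims);
        const float c = std::cos(theta), s = std::sin(theta);
        const float x0 = x[i], x1 = x[i + 1];
        x[i]     = x0 * c - x1 * s;
        x[i + 1] = x0 * s + x1 * c;
    }
}

// NEOX mode rotates split halves (x[i], x[i + n_dims/2]) with the same schedule.
void rope_neox(float * x, int n_dims, int pos, float freq_base) {
    const int half = n_dims / 2;
    for (int i = 0; i < half; i++) {
        const float theta = pos * std::pow(freq_base, -2.0f * i / n_dims);
        const float c = std::cos(theta), s = std::sin(theta);
        const float x0 = x[i], x1 = x[i + half];
        x[i]        = x0 * c - x1 * s;
        x[i + half] = x0 * s + x1 * c;
    }
}

The two layouts are not interconvertible after the fact, which is why each gets its own compiled shader and the dispatcher picks by the GGML_ROPE_TYPE_NEOX mode bit.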
@@ -1351,11 +1393,15 @@ static void ggml_vk_cpy_f16_f32(Args&&... args) {
 }
 
 static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    int64_t n = ggml_nelements(op);
     switch (op->op) {
         case GGML_OP_UNARY:
+            if (n % 4 != 0) return false;
             switch (ggml_get_unary_op(op)) {
-                case GGML_UNARY_OP_RELU:
                 case GGML_UNARY_OP_GELU:
+                    if (n % 8 != 0) return false;
+                    // fall through
+                case GGML_UNARY_OP_RELU:
                 case GGML_UNARY_OP_SILU:
                     return ggml_is_contiguous(op->src[0]);
                 default:
@@ -1373,8 +1419,18 @@ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, cons
         case GGML_OP_SOFT_MAX:
         case GGML_OP_RMS_NORM:
         case GGML_OP_NORM:
-        case GGML_OP_ROPE:
             return true;
+        case GGML_OP_ROPE:
+            {
+                const int mode = ((const int32_t *) op->op_params)[2];
+                if (mode & GGML_ROPE_TYPE_MROPE) {
+                    return false;
+                }
+                if (mode & GGML_ROPE_TYPE_VISION) {
+                    return false;
+                }
+                return true;
+            }
         case GGML_OP_DUP:
         case GGML_OP_CPY:
         case GGML_OP_CONT:
@@ -1413,8 +1469,8 @@ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, cons
 
             switch (op->src[0]->type) {
                 case GGML_TYPE_F32:
-                case GGML_TYPE_Q6_K:
                     return op->ne[3] == 1;
+                case GGML_TYPE_Q6_K:
                 case GGML_TYPE_F16:
                 case GGML_TYPE_Q8_0:
                 case GGML_TYPE_Q4_0:
@@ -1515,9 +1571,11 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
         const static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
         uint32_t off_src0 = 0;
         uint32_t off_src1 = 0;
+        uint32_t off_src2 = 0;
         uint32_t off_dst  = 0;
         const std::shared_ptr<kp::Tensor>& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor;
         const std::shared_ptr<kp::Tensor>& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor;
+        const std::shared_ptr<kp::Tensor>& id_src2 = src2 ? ggml_vk_get_tensor(src2, &off_src2) : nullTensor;
         const std::shared_ptr<kp::Tensor>& id_dst  = dst ? ggml_vk_get_tensor(dst, &off_dst) : nullTensor;
 
         switch (dst->op) {
@@ -1593,11 +1651,16 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
 #pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
                     GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
 
-#pragma message("TODO: add ALiBi support")
-#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/7192")
-                    GGML_ASSERT(max_bias == 0.0f);
+                    const int64_t nrows_x = ggml_nrows(src0);
+                    const int64_t nrows_y = src0->ne[1];
+
+                    const uint32_t n_head      = nrows_x/nrows_y;
+                    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
 
-                    ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
+                    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+                    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+                    ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale, max_bias, m0, m1, n_head_log2);
                 } break;
             case GGML_OP_DIAG_MASK_INF:
                 {
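Editor's note: the hunk above wires ALiBi support into the Kompute soft-max by precomputing the slope bases m0/m1 on the host and passing them to the shader. For reference, a per-head slope is then typically derived the way ggml's CPU soft-max path does it (a sketch; the shader side is not part of this diff):

#include <cmath>
#include <cstdint>

// Per-head ALiBi slope from the m0/m1/n_head_log2 push constants computed above.
// Heads below n_head_log2 (the largest power of two <= n_head) use base m0;
// the rest interpolate with base m1 and odd exponents. With max_bias == 0,
// m0 == m1 == 1 and the slope degenerates to 1 (no positional bias).
static float alibi_slope(uint32_t h, uint32_t n_head_log2, float m0, float m1) {
    return h < n_head_log2 ? std::pow(m0, (float)(h + 1))
                           : std::pow(m1, (float)(2*(h - n_head_log2) + 1));
}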
@@ -1649,38 +1712,44 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                         case GGML_TYPE_F16:
                             ggml_vk_mul_mat_f16(
                                 seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, ne13, nb10, nb11, nb12,
+                                ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+                                ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
                                 ne0, ne1, r2, r3
                             );
                             break;
                         case GGML_TYPE_Q8_0:
                             ggml_vk_mul_mat_q8_0(
                                 seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
+                                ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+                                nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
                             );
                             break;
                         case GGML_TYPE_Q4_0:
                             ggml_vk_mul_mat_q4_0(
                                 seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
+                                ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+                                nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
                             );
                             break;
                         case GGML_TYPE_Q4_1:
                             ggml_vk_mul_mat_q4_1(
                                 seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
+                                ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+                                nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
                            );
                             break;
                         case GGML_TYPE_Q4_K:
                             ggml_vk_mul_mat_q4_k(
                                 seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
+                                ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+                                nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
                             );
                             break;
                         case GGML_TYPE_Q6_K:
                             ggml_vk_mul_mat_q6_k(
                                 seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                ne00, ne10, ne0, ne1, ne01, ne11, ne12, ne02
+                                ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+                                nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
                             );
                             break;
                         default: {
@@ -1709,13 +1778,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                 } break;
             case GGML_OP_ROPE:
                 {
-#pragma message("TODO: implement phi3 frequency factors support")
-#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7225")
-                    GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
-
-#pragma message("TODO: update rope NORM mode to match NEOX mode")
-#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634")
-
                     GGML_ASSERT(ne10 == ne02);
                     GGML_ASSERT(src0t == dstt);
                     // const int n_past = ((int32_t *) dst->op_params)[0];
@@ -1724,6 +1786,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                     // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
                     const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
+                    const bool has_freq_factors = dst->src[2] != nullptr;
+
                     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
                     memcpy(&freq_base,  (int32_t *) dst->op_params + 5, sizeof(float));
                     memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
@@ -1732,8 +1796,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                     memcpy(&beta_fast,  (int32_t *) dst->op_params + 9, sizeof(float));
                     memcpy(&beta_slow,  (int32_t *) dst->op_params + 10, sizeof(float));
                     ggml_vk_rope(
-                        seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
-                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
+                        seq, id_src0, id_src1, id_src2, id_dst, off_src0, off_src1, off_src2, off_dst, src0t, n_dims, mode, n_ctx_orig,
+                        freq_base, freq_scale, has_freq_factors, ext_factor, attn_factor, beta_fast, beta_slow,
                         ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
                     );
                 } break;
@@ -2176,9 +2240,12 @@ static const struct ggml_backend_reg_i ggml_backend_kompute_reg_i = {
 
 ggml_backend_reg_t ggml_backend_kompute_reg() {
     static ggml_backend_reg reg = {
-        /* .iface   = */ ggml_backend_kompute_reg_i,
-        /* .context = */ nullptr,
+        /* .api_version = */ GGML_BACKEND_API_VERSION,
+        /* .iface       = */ ggml_backend_kompute_reg_i,
+        /* .context     = */ nullptr,
     };
 
     return &reg;
 }
+
+GGML_BACKEND_DL_IMPL(ggml_backend_kompute_reg)
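Editor's note: the registration struct now carries GGML_BACKEND_API_VERSION, and GGML_BACKEND_DL_IMPL allows the backend to be compiled as a dynamically loadable module (see the expanded ggml-backend-reg.cpp and the new ggml_add_backend_library CMake helper elsewhere in this diff). Roughly, when dynamic backends are enabled, the macro exposes a C entry point the registry can dlopen and call; a sketch of the pattern, not the verbatim macro:

#include "ggml-backend.h"   // ggml_backend_reg_t; assumes ggml_backend_kompute_reg is in scope

// A single well-known symbol per backend module, returning the
// version-tagged registration struct defined above.
extern "C" ggml_backend_reg_t ggml_backend_init(void) {
    return ggml_backend_kompute_reg();
}

The loader can then compare the returned reg's api_version against its own GGML_BACKEND_API_VERSION before using the interface.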
package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt
@@ -4,19 +4,16 @@ find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
 
 message(STATUS "Metal framework found")
 
-add_library(ggml-metal
-            ggml-metal.m
-           )
+ggml_add_backend_library(ggml-metal
+                         ggml-metal.m
+                        )
 
 target_link_libraries(ggml-metal PRIVATE
-    ggml-base
     ${FOUNDATION_LIBRARY}
     ${METAL_FRAMEWORK}
     ${METALKIT_FRAMEWORK}
 )
 
-target_include_directories(ggml-metal PRIVATE . ..)
-
 if (GGML_METAL_NDEBUG)
     add_compile_definitions(GGML_METAL_NDEBUG)
 endif()
package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -102,6 +102,21 @@ typedef struct {
     uint64_t nb3;
 } ggml_metal_kargs_cpy;
 
+typedef struct {
+    int64_t  ne10;
+    int64_t  ne11;
+    int64_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    uint64_t offs;
+    bool     inplace;
+} ggml_metal_kargs_set;
+
 typedef struct {
     int32_t  ne00;
     int32_t  ne01;
@@ -192,6 +207,30 @@ typedef struct {
     int16_t  r3;
 } ggml_metal_kargs_mul_mv;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int16_t  r2;
+    int16_t  r3;
+    int16_t  nsg;
+    int16_t  nxpsg;
+    int16_t  r1ptg;
+} ggml_metal_kargs_mul_mv_ext;
+
 typedef struct {
     int32_t  nei0;
     int32_t  nei1;
package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt
@@ -20,6 +20,11 @@ find_package(MUSAToolkit)
 if (MUSAToolkit_FOUND)
     message(STATUS "MUSA Toolkit found")
 
+    if (NOT DEFINED MUSA_ARCHITECTURES)
+        set(MUSA_ARCHITECTURES "21;22")
+    endif()
+    message(STATUS "Using MUSA architectures: ${MUSA_ARCHITECTURES}")
+
     file(GLOB   GGML_HEADERS_MUSA "../ggml-cuda/*.cuh")
     list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")
 
@@ -44,15 +49,17 @@ if (MUSAToolkit_FOUND)
 
     set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX)
     foreach(SOURCE ${GGML_SOURCES_MUSA})
-        set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_21 --cuda-gpu-arch=mp_22")
+        set(COMPILE_FLAGS "-x musa -mtgpu")
+        foreach(ARCH ${MUSA_ARCHITECTURES})
+            set(COMPILE_FLAGS "${COMPILE_FLAGS} --cuda-gpu-arch=mp_${ARCH}")
+        endforeach()
+        set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS ${COMPILE_FLAGS})
     endforeach()
 
-    add_library(ggml-musa
-                ${GGML_HEADERS_MUSA}
-                ${GGML_SOURCES_MUSA})
-
-    target_link_libraries(ggml-musa PRIVATE ggml-base)
-    target_include_directories(ggml-musa PRIVATE . ..)
+    ggml_add_backend_library(ggml-musa
+                             ${GGML_HEADERS_MUSA}
+                             ${GGML_SOURCES_MUSA}
+                            )
 
     # TODO: do not use CUDA definitions for MUSA
     target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)