@fugood/llama.node 0.3.3 → 0.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (225)
  1. package/CMakeLists.txt +5 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +29 -1
  17. package/package.json +1 -1
  18. package/src/EmbeddingWorker.cpp +15 -5
  19. package/src/EmbeddingWorker.h +2 -1
  20. package/src/LlamaCompletionWorker.cpp +17 -1
  21. package/src/LlamaContext.cpp +86 -18
  22. package/src/LlamaContext.h +2 -0
  23. package/src/llama.cpp/.github/workflows/build.yml +197 -159
  24. package/src/llama.cpp/.github/workflows/docker.yml +5 -8
  25. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  26. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  27. package/src/llama.cpp/CMakeLists.txt +11 -6
  28. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  29. package/src/llama.cpp/cmake/common.cmake +33 -0
  30. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  31. package/src/llama.cpp/common/CMakeLists.txt +6 -2
  32. package/src/llama.cpp/common/arg.cpp +426 -245
  33. package/src/llama.cpp/common/common.cpp +143 -80
  34. package/src/llama.cpp/common/common.h +81 -24
  35. package/src/llama.cpp/common/sampling.cpp +53 -19
  36. package/src/llama.cpp/common/sampling.h +22 -1
  37. package/src/llama.cpp/common/speculative.cpp +274 -0
  38. package/src/llama.cpp/common/speculative.h +28 -0
  39. package/src/llama.cpp/docs/build.md +101 -148
  40. package/src/llama.cpp/examples/CMakeLists.txt +32 -13
  41. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +5 -4
  43. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  46. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  47. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  48. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  49. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  50. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  51. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  52. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  54. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  55. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  57. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  58. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  59. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  60. package/src/llama.cpp/examples/imatrix/imatrix.cpp +11 -2
  61. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/infill/infill.cpp +1 -1
  63. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  64. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +405 -316
  65. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  66. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  67. package/src/llama.cpp/examples/llava/clip.cpp +262 -66
  68. package/src/llama.cpp/examples/llava/clip.h +8 -2
  69. package/src/llama.cpp/examples/llava/llava-cli.cpp +1 -1
  70. package/src/llama.cpp/examples/llava/llava.cpp +46 -19
  71. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +1 -1
  72. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  73. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
  75. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  76. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -1
  77. package/src/llama.cpp/examples/lookup/lookup.cpp +2 -2
  78. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/main/main.cpp +9 -5
  80. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  81. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
  83. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  84. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  87. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  88. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  89. package/src/llama.cpp/examples/retrieval/retrieval.cpp +4 -4
  90. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  91. package/src/llama.cpp/examples/run/run.cpp +911 -0
  92. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -4
  94. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -7
  95. package/src/llama.cpp/examples/server/server.cpp +1758 -886
  96. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  97. package/src/llama.cpp/examples/server/utils.hpp +94 -304
  98. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  99. package/src/llama.cpp/examples/simple/simple.cpp +4 -0
  100. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +1 -1
  101. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +3 -0
  102. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/speculative/speculative.cpp +16 -15
  104. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  105. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  106. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/tokenize/tokenize.cpp +1 -1
  108. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  109. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  110. package/src/llama.cpp/ggml/CMakeLists.txt +46 -34
  111. package/src/llama.cpp/ggml/include/ggml-backend.h +16 -0
  112. package/src/llama.cpp/ggml/include/ggml-cpu.h +7 -49
  113. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  114. package/src/llama.cpp/ggml/include/ggml.h +106 -24
  115. package/src/llama.cpp/ggml/src/CMakeLists.txt +73 -24
  116. package/src/llama.cpp/ggml/src/ggml-alloc.c +0 -1
  117. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +51 -11
  118. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +379 -22
  119. package/src/llama.cpp/ggml/src/ggml-backend.cpp +4 -4
  120. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -7
  121. package/src/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +5 -2
  122. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +33 -3
  123. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  124. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  125. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +95 -35
  126. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  127. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  128. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  129. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  130. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  131. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  132. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  133. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  134. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  135. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +288 -213
  136. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  137. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  138. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/common.h +19 -22
  139. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.cpp +93 -92
  140. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.h +2 -9
  141. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  142. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +892 -190
  143. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +2 -24
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +15 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +38 -25
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +552 -399
  151. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +101 -136
  152. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +2 -2
  153. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +7 -10
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  155. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -6
  156. package/src/llama.cpp/ggml/src/ggml-impl.h +32 -11
  157. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +13 -9
  158. package/src/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +131 -64
  159. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +3 -6
  160. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +39 -0
  161. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +14 -7
  162. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  163. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  164. package/src/llama.cpp/ggml/src/ggml-opt.cpp +67 -80
  165. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -9
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +3 -5
  167. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +5 -2
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +13 -10
  169. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +2 -11
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +2 -2
  172. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  173. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  174. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +32 -13
  175. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +80 -61
  176. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  177. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +159 -114
  178. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  179. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  180. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +6 -20
  181. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +4 -3
  182. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +8 -8
  183. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  184. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  185. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +4 -1
  187. package/src/llama.cpp/ggml/src/ggml-threading.h +4 -2
  188. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +21 -7
  189. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1718 -399
  190. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +3 -1
  191. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +105 -31
  192. package/src/llama.cpp/ggml/src/ggml.c +367 -207
  193. package/src/llama.cpp/include/llama-cpp.h +25 -0
  194. package/src/llama.cpp/include/llama.h +26 -19
  195. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  196. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  197. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  198. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  199. package/src/llama.cpp/src/CMakeLists.txt +2 -7
  200. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  201. package/src/llama.cpp/src/llama-grammar.h +2 -5
  202. package/src/llama.cpp/src/llama-sampling.cpp +35 -90
  203. package/src/llama.cpp/src/llama-vocab.cpp +6 -1
  204. package/src/llama.cpp/src/llama.cpp +1748 -640
  205. package/src/llama.cpp/src/unicode.cpp +62 -51
  206. package/src/llama.cpp/src/unicode.h +9 -10
  207. package/src/llama.cpp/tests/CMakeLists.txt +48 -37
  208. package/src/llama.cpp/tests/test-arg-parser.cpp +2 -2
  209. package/src/llama.cpp/tests/test-backend-ops.cpp +140 -21
  210. package/src/llama.cpp/tests/test-chat-template.cpp +50 -4
  211. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  212. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  213. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  214. package/src/llama.cpp/tests/test-quantize-fns.cpp +3 -3
  215. package/src/llama.cpp/tests/test-rope.cpp +61 -20
  216. package/src/llama.cpp/tests/test-sampling.cpp +2 -2
  217. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  218. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  219. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  220. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  221. package/src/llama.cpp/ggml/include/ggml-amx.h +0 -25
  222. package/src/llama.cpp/ggml/src/ggml-aarch64.c +0 -129
  223. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -19
  224. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  225. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
@@ -813,7 +813,7 @@ load_tiles_q4_K(const void *__restrict__ vx, int *__restrict__ x_ql,
813
813
  x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
814
814
  }
815
815
 
816
- const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
816
+ constexpr int blocks_per_tile_x_row = QI4_K > WARP_SIZE ? 1 : WARP_SIZE / QI4_K; // == 1 if QK_K == 256
817
817
  const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
818
818
 
819
819
  #pragma unroll
@@ -961,7 +961,7 @@ load_tiles_q5_K(const void *__restrict__ vx, int *__restrict__ x_ql,
961
961
  x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
962
962
  }
963
963
 
964
- const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
964
+ constexpr int blocks_per_tile_x_row = QI5_K > WARP_SIZE ? 1 : WARP_SIZE / QI5_K; // == 1 if QK_K == 256
965
965
  const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
966
966
 
967
967
  #pragma unroll
@@ -1109,7 +1109,7 @@ load_tiles_q6_K(const void *__restrict__ vx, int *__restrict__ x_ql,
1109
1109
  dpct::sub_sat());
1110
1110
  }
1111
1111
 
1112
- const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
1112
+ constexpr int blocks_per_tile_x_row = QI6_K > WARP_SIZE ? 1 : WARP_SIZE / QI6_K; // == 1 if QK_K == 256
1113
1113
  const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
1114
1114
  float * x_dmf = (float *) x_dm;
1115
1115
 
@@ -3020,9 +3020,9 @@ void ggml_sycl_op_mul_mat_q(
3020
3020
  break;
3021
3021
  }
3022
3022
 
3023
- (void) src1;
3024
- (void) dst;
3025
- (void) src1_ddf_i;
3023
+ GGML_UNUSED(src1);
3024
+ GGML_UNUSED(dst);
3025
+ GGML_UNUSED(src1_ddf_i);
3026
3026
  }
3027
3027
  catch (sycl::exception const &exc) {
3028
3028
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -753,11 +753,7 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
753
753
  const sycl::range<3> block_nums(1, 1, block_num_y);
754
754
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
755
755
  {
756
-
757
- stream->submit([&](sycl::handler &cgh) {
758
- auto iq2xs_grid_ptr_ct1 = &iq2xs_grid[0];
759
- auto ksigns64_ptr_ct1 = &ksigns64[0];
760
-
756
+ stream->submit([&](sycl::handler & cgh) {
761
757
  cgh.parallel_for(
762
758
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
763
759
  [=](sycl::nd_item<3> item_ct1)
@@ -780,9 +776,6 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
780
776
  {
781
777
 
782
778
  stream->submit([&](sycl::handler &cgh) {
783
- auto iq2xs_grid_ptr_ct1 = &iq2xs_grid[0];
784
- auto ksigns64_ptr_ct1 = &ksigns64[0];
785
-
786
779
  cgh.parallel_for(
787
780
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
788
781
  [=](sycl::nd_item<3> item_ct1)
@@ -805,9 +798,6 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
805
798
  {
806
799
 
807
800
  stream->submit([&](sycl::handler &cgh) {
808
- auto iq3xxs_grid_ptr_ct1 = &iq3xxs_grid[0];
809
- auto ksigns64_ptr_ct1 = &ksigns64[0];
810
-
811
801
  cgh.parallel_for(
812
802
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
813
803
  [=](sycl::nd_item<3> item_ct1)
@@ -830,8 +820,6 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
830
820
  {
831
821
 
832
822
  stream->submit([&](sycl::handler &cgh) {
833
- auto iq3s_grid_ptr_ct1 = &iq3s_grid[0];
834
-
835
823
  cgh.parallel_for(
836
824
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
837
825
  [=](sycl::nd_item<3> item_ct1)
@@ -854,9 +842,6 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
854
842
  {
855
843
 
856
844
  stream->submit([&](sycl::handler &cgh) {
857
- auto iq1s_grid_ptr_ct1 = &iq1s_grid_gpu[0];
858
- auto ksigns64_ptr_ct1 = &ksigns64[0];
859
-
860
845
  cgh.parallel_for(
861
846
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
862
847
  [=](sycl::nd_item<3> item_ct1)
@@ -954,7 +939,7 @@ void ggml_sycl_op_mul_mat_vec_q(
954
939
  const size_t q8_1_bs = QK8_1;
955
940
  // the main device has a larger memory buffer to hold the results from all GPUs
956
941
  // nrows_dst == nrows of the matrix that the kernel writes into
957
- const int64_t nrows_dst = id == ctx.device ? ne00 : row_diff;
942
+
958
943
  for (int i = 0; i < src1_ncols; i++)
959
944
  {
960
945
  const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs;
@@ -1023,7 +1008,8 @@ void ggml_sycl_op_mul_mat_vec_q(
1023
1008
  break;
1024
1009
  }
1025
1010
  }
1026
- (void) src1;
1027
- (void) dst;
1028
- (void) src1_ddf_i;
1011
+ GGML_UNUSED(src1);
1012
+ GGML_UNUSED(dst);
1013
+ GGML_UNUSED(src1_ddf_i);
1014
+ GGML_UNUSED(ctx);
1029
1015
  }
@@ -31,7 +31,7 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep
31
31
  */
32
32
  item_ct1.barrier(sycl::access::fence_space::local_space);
33
33
  mean_var = 0.f;
34
- int nreduce = nwarps / WARP_SIZE;
34
+ size_t nreduce = nwarps / WARP_SIZE;
35
35
  for (size_t i = 0; i < nreduce; i += 1)
36
36
  {
37
37
  mean_var += s_sum[lane_id + i * WARP_SIZE];
@@ -55,7 +55,7 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
55
55
  const int nthreads = item_ct1.get_local_range(2);
56
56
  const int nwarps = nthreads / WARP_SIZE;
57
57
  start += item_ct1.get_local_id(2);
58
- int nreduce = nwarps / WARP_SIZE;
58
+ size_t nreduce = nwarps / WARP_SIZE;
59
59
 
60
60
  if (end >= ne_elements) {
61
61
  end = ne_elements;
@@ -163,7 +163,7 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
163
163
  converged control flow. You may need to adjust the code.
164
164
  */
165
165
  item_ct1.barrier(sycl::access::fence_space::local_space);
166
- int nreduce = nwarps / WARP_SIZE;
166
+ size_t nreduce = nwarps / WARP_SIZE;
167
167
  tmp = 0.f;
168
168
  for (size_t i = 0; i < nreduce; i += 1)
169
169
  {
@@ -352,6 +352,7 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
352
352
  (void)src1;
353
353
  (void)dst;
354
354
  (void)src1_dd;
355
+ GGML_UNUSED(ctx);
355
356
  }
356
357
 
357
358
  void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
@@ -40,14 +40,14 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
40
40
 
41
41
  try {
42
42
  // Perform matrix multiplication using oneMKL GEMM
43
- oneapi::mkl::blas::column_major::gemm(*stream,
44
- oneapi::mkl::transpose::nontrans, src1_op,
45
- ne0, ne1, ne01,
46
- alpha,
47
- src0_d, ne00,
48
- src1_d, ldb,
49
- beta,
50
- dst_d, ne0);
43
+ #ifdef GGML_SYCL_NVIDIA
44
+ oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream },
45
+ oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d,
46
+ ne00, src1_d, ldb, beta, dst_d, ne0);
47
+ #else
48
+ oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,
49
+ src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
50
+ #endif
51
51
  }
52
52
  catch (sycl::exception const& exc) {
53
53
  std::cerr << exc.what() << std::endl;
@@ -269,7 +269,8 @@ void ggml_sycl_op_rope(
269
269
  }
270
270
  }
271
271
 
272
- (void) src1;
273
- (void) dst;
274
- (void) src1_dd;
272
+ GGML_UNUSED(src1);
273
+ GGML_UNUSED(dst);
274
+ GGML_UNUSED(src1_dd);
275
+ GGML_UNUSED(ctx);
275
276
  }
@@ -16,7 +16,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
16
16
  const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
17
17
  const int nthreads = block_size;
18
18
  const int nwarps = nthreads / WARP_SIZE;
19
- int nreduce = nwarps / WARP_SIZE;
19
+ size_t nreduce = nwarps / WARP_SIZE;
20
20
  float slope = 1.0f;
21
21
 
22
22
  // ALiBi
@@ -53,8 +53,9 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
53
53
  if (block_size > WARP_SIZE) {
54
54
  if (warp_id == 0) {
55
55
  buf[lane_id] = -INFINITY;
56
- for (size_t i = 1; i < nreduce; i += 1)
56
+ for (size_t i = 1; i < nreduce; i += 1) {
57
57
  buf[lane_id + i * WARP_SIZE] = -INFINITY;
58
+ }
58
59
  }
59
60
  item_ct1.barrier(sycl::access::fence_space::local_space);
60
61
 
@@ -63,8 +64,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
63
64
  }
64
65
  item_ct1.barrier(sycl::access::fence_space::local_space);
65
66
  max_val = buf[lane_id];
66
- for (size_t i = 1; i < nreduce; i += 1)
67
- {
67
+ for (size_t i = 1; i < nreduce; i += 1) {
68
68
  max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
69
69
  }
70
70
  max_val = warp_reduce_max(max_val, item_ct1);
@@ -89,8 +89,9 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
89
89
  item_ct1.barrier(sycl::access::fence_space::local_space);
90
90
  if (warp_id == 0) {
91
91
  buf[lane_id] = 0.f;
92
- for (size_t i = 1; i < nreduce; i += 1)
92
+ for (size_t i = 1; i < nreduce; i += 1) {
93
93
  buf[lane_id + i * WARP_SIZE] = 0.f;
94
+ }
94
95
  }
95
96
  item_ct1.barrier(sycl::access::fence_space::local_space);
96
97
 
@@ -100,8 +101,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
100
101
  item_ct1.barrier(sycl::access::fence_space::local_space);
101
102
 
102
103
  tmp = buf[lane_id];
103
- for (size_t i = 1; i < nreduce; i += 1)
104
- {
104
+ for (size_t i = 1; i < nreduce; i += 1) {
105
105
  tmp += buf[lane_id + i * WARP_SIZE];
106
106
  }
107
107
  tmp = warp_reduce_sum(tmp, item_ct1);
@@ -68,4 +68,5 @@ void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml
68
68
  const int max_period = dst->op_params[1];
69
69
 
70
70
  timestep_embedding_f32_sycl(src0_d, dst_d, src0->ne[0], dst->nb[1], dim, max_period, stream);
71
+ GGML_UNUSED(src1);
71
72
  }
@@ -59,7 +59,7 @@ static void rwkv_wkv_f32_kernel(
59
59
  float y = 0;
60
60
 
61
61
  // Process in chunks of 4 for better vectorization
62
- sycl::float4 k4, r4, tf4, td4, s4, kv4;
62
+ sycl::float4 k4, r4, tf4, td4, s4;
63
63
  #pragma unroll
64
64
  for (int j = 0; j < head_size; j += 4) {
65
65
  // Load data in vec4 chunks
@@ -135,4 +135,7 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* s
135
135
  );
136
136
  });
137
137
  });
138
+
139
+ GGML_UNUSED(src0);
140
+ GGML_UNUSED(src1);
138
141
  }
@@ -1,11 +1,13 @@
1
1
  #pragma once
2
2
 
3
+ #include "ggml.h"
4
+
3
5
  #ifdef __cplusplus
4
6
  extern "C" {
5
7
  #endif
6
8
 
7
- void ggml_critical_section_start(void);
8
- void ggml_critical_section_end(void);
9
+ GGML_API void ggml_critical_section_start(void);
10
+ GGML_API void ggml_critical_section_end(void);
9
11
 
10
12
  #ifdef __cplusplus
11
13
  }
@@ -3,13 +3,27 @@ find_package(Vulkan COMPONENTS glslc REQUIRED)
3
3
  if (Vulkan_FOUND)
4
4
  message(STATUS "Vulkan found")
5
5
 
6
- add_library(ggml-vulkan
7
- ggml-vulkan.cpp
8
- ../../include/ggml-vulkan.h
9
- )
6
+ ggml_add_backend_library(ggml-vulkan
7
+ ggml-vulkan.cpp
8
+ ../../include/ggml-vulkan.h
9
+ )
10
+
11
+ # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported.
12
+ # If it's not, there will be an error to stderr.
13
+ # If it's supported, set a define to indicate that we should compile those shaders
14
+ execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp"
15
+ OUTPUT_VARIABLE glslc_output
16
+ ERROR_VARIABLE glslc_error)
17
+
18
+ if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*")
19
+ message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc")
20
+ else()
21
+ message(STATUS "GL_NV_cooperative_matrix2 supported by glslc")
22
+ add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
23
+ endif()
10
24
 
11
- target_link_libraries(ggml-vulkan PRIVATE ggml-base Vulkan::Vulkan)
12
- target_include_directories(ggml-vulkan PRIVATE . .. ${CMAKE_CURRENT_BINARY_DIR})
25
+ target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
26
+ target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
13
27
 
14
28
  # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build
15
29
  # Possibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector
@@ -67,7 +81,7 @@ if (Vulkan_FOUND)
67
81
  --target-cpp ${_ggml_vk_source}
68
82
  --no-clean
69
83
 
70
- DEPENDS ${_ggml_vk_shader_deps}
84
+ DEPENDS ${_ggml_vk_shader_deps} ${_ggml_vk_genshaders_cmd}
71
85
  COMMENT "Generate vulkan shaders"
72
86
  )
73
87