@fugood/llama.node 0.3.17 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. package/CMakeLists.txt +3 -1
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +39 -2
  19. package/lib/index.js +132 -1
  20. package/lib/index.ts +203 -3
  21. package/package.json +2 -1
  22. package/src/EmbeddingWorker.cpp +1 -1
  23. package/src/LlamaCompletionWorker.cpp +366 -19
  24. package/src/LlamaCompletionWorker.h +30 -10
  25. package/src/LlamaContext.cpp +213 -5
  26. package/src/LlamaContext.h +12 -0
  27. package/src/common.hpp +15 -0
  28. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
  29. package/src/llama.cpp/.github/workflows/build.yml +41 -762
  30. package/src/llama.cpp/.github/workflows/docker.yml +5 -2
  31. package/src/llama.cpp/.github/workflows/release.yml +716 -0
  32. package/src/llama.cpp/.github/workflows/server.yml +12 -12
  33. package/src/llama.cpp/CMakeLists.txt +5 -17
  34. package/src/llama.cpp/cmake/build-info.cmake +8 -2
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
  36. package/src/llama.cpp/common/CMakeLists.txt +31 -3
  37. package/src/llama.cpp/common/arg.cpp +48 -29
  38. package/src/llama.cpp/common/chat.cpp +128 -106
  39. package/src/llama.cpp/common/chat.h +2 -0
  40. package/src/llama.cpp/common/common.cpp +37 -1
  41. package/src/llama.cpp/common/common.h +18 -9
  42. package/src/llama.cpp/common/llguidance.cpp +1 -0
  43. package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
  44. package/src/llama.cpp/common/minja/minja.hpp +69 -36
  45. package/src/llama.cpp/common/regex-partial.cpp +204 -0
  46. package/src/llama.cpp/common/regex-partial.h +56 -0
  47. package/src/llama.cpp/common/sampling.cpp +57 -50
  48. package/src/llama.cpp/examples/CMakeLists.txt +2 -23
  49. package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
  50. package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
  51. package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
  52. package/src/llama.cpp/examples/training/finetune.cpp +96 -0
  53. package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
  54. package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
  55. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
  56. package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
  57. package/src/llama.cpp/ggml/include/ggml.h +10 -7
  58. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
  60. package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
  61. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
  62. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
  63. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
  64. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
  65. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
  66. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  67. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  68. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  69. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
  70. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
  71. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
  72. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
  73. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
  74. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
  75. package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
  76. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
  77. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
  78. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  79. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
  80. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
  81. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
  82. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
  83. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
  84. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  85. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
  86. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
  87. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  88. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
  89. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
  90. package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
  91. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
  92. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
  93. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
  94. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
  95. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
  96. package/src/llama.cpp/ggml/src/ggml.c +29 -20
  97. package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
  98. package/src/llama.cpp/include/llama.h +52 -11
  99. package/src/llama.cpp/requirements/requirements-all.txt +3 -3
  100. package/src/llama.cpp/scripts/xxd.cmake +1 -1
  101. package/src/llama.cpp/src/CMakeLists.txt +1 -0
  102. package/src/llama.cpp/src/llama-adapter.cpp +6 -0
  103. package/src/llama.cpp/src/llama-arch.cpp +3 -0
  104. package/src/llama.cpp/src/llama-batch.cpp +5 -1
  105. package/src/llama.cpp/src/llama-batch.h +2 -1
  106. package/src/llama.cpp/src/llama-chat.cpp +17 -7
  107. package/src/llama.cpp/src/llama-chat.h +1 -0
  108. package/src/llama.cpp/src/llama-context.cpp +389 -501
  109. package/src/llama.cpp/src/llama-context.h +44 -32
  110. package/src/llama.cpp/src/llama-cparams.h +1 -0
  111. package/src/llama.cpp/src/llama-graph.cpp +20 -38
  112. package/src/llama.cpp/src/llama-graph.h +12 -8
  113. package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
  114. package/src/llama.cpp/src/llama-kv-cache.h +271 -85
  115. package/src/llama.cpp/src/llama-memory.h +11 -1
  116. package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
  117. package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
  118. package/src/llama.cpp/src/llama-model-saver.h +37 -0
  119. package/src/llama.cpp/src/llama-model.cpp +316 -69
  120. package/src/llama.cpp/src/llama-model.h +8 -1
  121. package/src/llama.cpp/src/llama-quant.cpp +15 -13
  122. package/src/llama.cpp/src/llama-sampling.cpp +18 -6
  123. package/src/llama.cpp/src/llama-vocab.cpp +42 -4
  124. package/src/llama.cpp/src/llama-vocab.h +6 -0
  125. package/src/llama.cpp/src/llama.cpp +14 -0
  126. package/src/llama.cpp/tests/CMakeLists.txt +10 -2
  127. package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
  128. package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
  129. package/src/llama.cpp/tests/test-chat.cpp +3 -1
  130. package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
  131. package/src/llama.cpp/tests/test-opt.cpp +33 -21
  132. package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
  133. package/src/llama.cpp/tests/test-sampling.cpp +1 -1
  134. package/src/llama.cpp/tools/CMakeLists.txt +39 -0
  135. package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
  136. package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
  137. package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
  138. package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
  139. package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
  140. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
  141. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
  142. package/src/llama.cpp/tools/mtmd/clip.h +99 -0
  143. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
  144. package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
  145. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
  146. package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
  147. package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
  148. package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
  149. package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
  150. package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
  151. package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
  152. package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
  153. package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
  154. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
  155. package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
  156. package/src/llama.cpp/examples/infill/infill.cpp +0 -590
  157. package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
  158. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
  159. package/src/llama.cpp/examples/llava/clip.h +0 -135
  160. package/src/llama.cpp/examples/llava/llava.cpp +0 -586
  161. package/src/llama.cpp/examples/llava/llava.h +0 -49
  162. package/src/llama.cpp/examples/llava/mtmd.h +0 -168
  163. package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
  164. /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
  165. /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
  166. /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
  167. /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
  168. /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
  169. /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
  170. /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
  171. /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
  172. /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
  173. /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
  174. /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
  175. /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
  176. /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
  177. /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
  178. /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
  179. /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
  180. /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
  181. /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
  182. /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
  183. /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
  184. /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
  185. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
  186. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
  187. /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
  188. /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
  189. /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
  190. /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
  191. /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
  192. /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
  193. /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0
package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -49,6 +49,8 @@ static bool g_sycl_loaded = false;
 int g_ggml_sycl_debug = 0;
 int g_ggml_sycl_disable_optimize = 0;
 int g_ggml_sycl_disable_graph = 0;
+int g_ggml_sycl_disable_dnn = 0;
+int g_ggml_sycl_prioritize_dmmv = 0;
 
 static ggml_sycl_device_info ggml_sycl_init() {
     ggml_sycl_device_info info = {};
@@ -193,13 +195,25 @@ static void ggml_check_sycl() try {
 
     if (!initialized) {
         g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
-        g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 0);
+        g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
         g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
+        g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
+        g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
         GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
         GGML_LOG_INFO("Running with Environment Variables:\n");
         GGML_LOG_INFO("  GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
         GGML_LOG_INFO("  GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
+#ifdef GGML_SYCL_GRAPH
         GGML_LOG_INFO("  GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
+#else
+        GGML_LOG_INFO("  GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
+#endif
+#if GGML_SYCL_DNNL
+        GGML_LOG_INFO("  GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
+#else
+        GGML_LOG_INFO("  GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
+#endif
+        GGML_LOG_INFO("  GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
         GGML_LOG_INFO("Build with Macros:\n");
 #if defined(GGML_SYCL_FORCE_MMQ)
         GGML_LOG_INFO("  GGML_SYCL_FORCE_MMQ: yes\n");
@@ -338,7 +352,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
         assert(tensor->view_src->buffer->buft == buffer->buft);
         return GGML_STATUS_SUCCESS;
     }
-    if (tensor->type == GGML_TYPE_Q4_0) {
+    if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) {
         ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
         tensor->extra = extra;
         ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
@@ -1982,19 +1996,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne10 = src1->ne[0];
-
+    GGML_ASSERT(ne00 == ne10);
 
     const int64_t row_diff = row_high - row_low;
 
     int id;
     SYCL_CHECK(
         CHECK_TRY_ERROR(id = get_current_device_id()));
-#if !GGML_SYCL_DNNL
-    const int64_t ne0 = dst->ne[0];
+
+    const int64_t ne0 = dst->ne[0]; // used by MKL only
     // the main device has a larger memory buffer to hold the results from all GPUs
     // ldc == nrows of the matrix that cuBLAS writes into
-    int ldc = id == ctx.device ? ne0 : row_diff;
-#endif
+    int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only
 
 #ifdef GGML_SYCL_F16
     bool use_fp16 = true; // TODO(Yu) SYCL capability check
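Note: the two "used by MKL only" comments refer to split-tensor mode, where the main device holds the full result matrix (leading dimension `ne0`) while every other device writes a compact `row_diff`-row slice. A toy calculation of the same `ldc` choice, with made-up sizes:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Made-up sizes: this device owns output rows [row_low, row_high).
    const int64_t ne0 = 4096;                    // full output height (main device buffer)
    const int64_t row_low = 1024, row_high = 2048;
    const int64_t row_diff = row_high - row_low; // this device's slice

    const int device = 1, main_device = 0;       // a non-main device
    const int64_t ldc = (device == main_device) ? ne0 : row_diff;
    std::printf("ldc = %lld\n", (long long) ldc); // 1024: compact slice layout
    return 0;
}
```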
@@ -2030,25 +2043,29 @@ inline void ggml_sycl_op_mul_mat_sycl(
                                          : src1_as_f16.get();
         ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
 
-#if !GGML_SYCL_DNNL
-        const sycl::half alpha_f16 = 1.0f;
-        const sycl::half beta_f16 = 0.0f;
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
-            *stream, oneapi::math::transpose::trans,
-            oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
-            &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
-            src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
-            dst_f16.get(), dpct::library_data_t::real_half, ldc,
-            dpct::library_data_t::real_half)));
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
-        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-#else
-        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
-            DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
-            dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
-        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+#if GGML_SYCL_DNNL
+        if (!g_ggml_sycl_disable_dnn) {
+            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
+                DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+                dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
+            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
+            to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+        }
+        else
 #endif
+        {
+            const sycl::half alpha_f16 = 1.0f;
+            const sycl::half beta_f16 = 0.0f;
+            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
+                *stream, oneapi::math::transpose::trans,
+                oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
+                &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
+                src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
+                dst_f16.get(), dpct::library_data_t::real_half, ldc,
+                dpct::library_data_t::real_half)));
+            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
+            to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+        }
     }
     else {
         // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
@@ -2069,18 +2086,22 @@ inline void ggml_sycl_op_mul_mat_sycl(
         const float * src0_ddf_i  = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
         const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
 
-#if !GGML_SYCL_DNNL
-        const float alpha = 1.0f;
-        const float beta = 0.0f;
-        SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
-            get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
-            src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
-            dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
-#else
-        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
-            DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
-            dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+#if GGML_SYCL_DNNL
+        if (!g_ggml_sycl_disable_dnn) {
+            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
+                DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+                dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+        }
+        else
 #endif
+        {
+            const float alpha = 1.0f;
+            const float beta = 0.0f;
+            SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
+                get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
+                src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
+                dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
+        }
     }
     GGML_UNUSED(dst);
     GGML_UNUSED(src1_ddq_i);
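Note: the fp16 and fp32 branches above now share one idiom: a compile-time `#if GGML_SYCL_DNNL` guard around a runtime `if (!g_ggml_sycl_disable_dnn)`, with the oneMath fallback block compiled unconditionally after the `#endif`. A minimal sketch of that control flow; every name here is a placeholder, not the backend's:

```cpp
#include <cstdio>

#define HAVE_FAST_PATH 1        // stand-in for GGML_SYCL_DNNL
static int g_disable_fast = 0;  // stand-in for g_ggml_sycl_disable_dnn

static void gemm(bool fast) {
    std::printf("gemm: %s path\n", fast ? "oneDNN-style" : "fallback");
}

int main() {
#if HAVE_FAST_PATH
    if (!g_disable_fast) {
        gemm(true);   // fast path, still vetoable at runtime
    }
    else
#endif
    {
        // With the macro off, this brace block is the only code that remains.
        gemm(false);
    }
    return 0;
}
```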
@@ -2694,139 +2715,180 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
-                                   const sycl::half *src1_as_f16, char *dst,
-                                   const void **ptrs_src, void **ptrs_dst,
-                                   int64_t ne12, int64_t ne13, int64_t ne23,
-                                   size_t nb02, size_t nb03, size_t nb12,
-                                   size_t nb13, size_t nbd2, size_t nbd3,
-                                   int64_t r2, int64_t r3,
-                                   const sycl::nd_item<3> &item_ct1) {
-    int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                  item_ct1.get_local_id(2);
-    int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) +
-                  item_ct1.get_local_id(1);
+static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,
+                                   const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
+                                   size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
+                                   int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
+    const int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
+    const int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
 
     if (i13 >= ne13 || i12 >= ne12) {
         return;
     }
 
-    int64_t i03 = i13 / r3;
-    int64_t i02 = i12 / r2;
+    const int64_t i03 = i13 / r3;
+    const int64_t i02 = i12 / r2;
+
+    const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
+    const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
+    uint8_t * dst_bytes = static_cast<uint8_t *>(dst);
 
-    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
-    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *) dst + i12*nbd2 + i13*nbd3;
+    ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
+    ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
+    ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3;
 }
 
-static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
-                                           const ggml_tensor *src0,
-                                           const ggml_tensor *src1,
-                                           ggml_tensor *dst) try {
+static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
+                                           const ggml_tensor * src1, ggml_tensor * dst) try {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
+    // TODO: see https://github.com/ggml-org/llama.cpp/pull/13155
+    // Batched mul_mat requires a rewrite to support both oneDNN and non-contiguous dst
+    GGML_ASSERT(ggml_is_contiguous(dst));
 
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    queue_ptr main_stream = ctx.stream();;
+    queue_ptr queue = ctx.stream();
 
-    void * src0_ddq = src0->data;
-    sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf = (float *) dst->data;
+    dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });
 
-    // convert src1 to fp16
+    const sycl::half * src0_f16 = static_cast<const sycl::half *>(src0->data);
+    float * dst_ddf = static_cast<float *>(dst->data);
+
+    const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
+    const size_t type_size_src1 = ggml_type_size(src1->type);
+    GGML_ASSERT(nb10 == type_size_src1);
+
+    // SRC1 strides
+    int64_t s11 = nb11 / type_size_src1;
+    int64_t s12 = nb12 / type_size_src1;
+    int64_t s13 = nb13 / type_size_src1;
     ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
+
+    // convert src1 to fp16
     if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
+        const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
+        GGML_ASSERT(to_fp16_nc_sycl != nullptr);
         const int64_t ne_src1 = ggml_nelements(src1);
         src1_f16_alloc.alloc(ne_src1);
-        GGML_ASSERT(to_fp16_sycl != nullptr);
-        to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
+        to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
+
+        src1_f16 = src1_f16_alloc.get();
+        s11 = ne10;
+        s12 = ne11 * s11;
+        s13 = ne12 * s12;
     }
-    sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
-                                                       : src1_f16_alloc.get();
 
-    char * dst_t;
+    ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
 
-    dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
-    dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
+    dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
+    dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
 
     // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
 
     const float alpha_f32 = 1.0f;
-    const float beta_f32 = 0.0f;
+    const float beta_f32  = 0.0f;
 
     const void * alpha = &alpha_f32;
     const void * beta = &beta_f32;
 
-    dst_t = (char *) dst_ddf;
-
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
+    GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
+    GGML_ASSERT(ne10 == ne00);
 
     // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
-
-    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
-        // there is no broadcast and src0, src1 are contiguous across dims 2, 3
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
-            (const char *) src0_as_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
-            (const char *) src1_f16, dpct::library_data_t::real_half, nb11 / nb10, nb12 / nb10, beta, (char *) dst_t,
-            cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type)));
-    } else {
-        const int ne23 = ne12*ne13;
-
-        ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
-        ggml_sycl_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
-        ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
-
-        sycl::range<3> block_dims(1, ne12, ne13);
-        /*
-        DPCT1049:47: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(main_stream->get_device(),
-                                         {sycl::aspect::fp16});
-
-            main_stream->submit([&](sycl::handler &cgh) {
-                const void **ptrs_src_get = ptrs_src.get();
-                void **ptrs_dst_get = ptrs_dst.get();
-                size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2;
-                size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2;
-                cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims),
-                                 [=](sycl::nd_item<3> item_ct1) {
-                                     k_compute_batched_ptrs(
-                                         src0_as_f16, src1_f16,
-                                         dst_t, ptrs_src_get,
-                                         ptrs_dst_get, ne12, ne13, ne23,
-                                         nb02, nb03, nb12_scaled, nb13_scaled,
-                                         nbd2, nbd3, r2, r3, item_ct1);
-                                 });
+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;
+
+#if GGML_SYCL_DNNL
+    if (!g_ggml_sycl_disable_dnn) {
+        auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
+                        (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
+
+            DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10,
+                src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
+                src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
+                dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
+        };
+
+        if (r2 == 1 && r3 == 1) {
+            if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+                dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
+            }
+            else {
+                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
+                    const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
+                    float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
+                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
+                }
+            }
+        } else {
+            // iterate over batches from smaller set of matrices (matrix 0)
+            for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
+                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
+                    const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
+                    float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
+                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
+                }
+            }
+        }
+    }
+    else
+#endif
+    {
+        if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+            // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
+                oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+                src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
+                src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
+                mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
+        } else {
+            const int ne23 = ne12 * ne13;
+
+            ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
+            ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
+            ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
+
+            sycl::range<3> block_dims(1, ne12, ne13);
+            queue->submit([&](sycl::handler & cgh) {
+                const void ** ptrs_src_get = ptrs_src.get();
+                void ** ptrs_dst_get = ptrs_dst.get();
+                size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
+                size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
+                cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                    k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
+                                           nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
+                });
             });
+
+            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
+                *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+                (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
+                (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
+                (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
         }
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
-            (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
-            (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
-            (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
     }
+} catch (const sycl::exception & exc) {
+    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
+    std::exit(1);
 }
-catch (sycl::exception const &exc) {
-    std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-              << ", line:" << __LINE__ << std::endl;
-    std::exit(1);
-}
+
+enum class mul_mat_algo {
+    DMMV = 0,
+    MMVQ = 1,
+    MUL_MAT_SYCL = 2,
+};
 
 inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
     // TODO: accuracy issues in MMQ
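Note: `r2 = ne12/ne02` and `r3 = ne13/ne03` are the broadcast factors used both by `k_compute_batched_ptrs` and by the new oneDNN loops: several src1 batches map onto one src0 batch via integer division. A small standalone illustration with assumed shapes:

```cpp
#include <cstdio>

int main() {
    // Assumed batch shapes: src0 carries 2x1 batches, src1 carries 6x2.
    const int ne02 = 2, ne03 = 1;  // src0 dims 2 and 3
    const int ne12 = 6, ne13 = 2;  // src1 dims 2 and 3
    const int r2 = ne12 / ne02;    // 3: each src0 batch serves three src1 batches
    const int r3 = ne13 / ne03;    // 2
    for (int i13 = 0; i13 < ne13; ++i13) {
        for (int i12 = 0; i12 < ne12; ++i12) {
            // Same index math as k_compute_batched_ptrs: i02 = i12/r2, i03 = i13/r3.
            std::printf("src1 batch (%d,%d) -> src0 batch (%d,%d)\n",
                        i12, i13, i12 / r2, i13 / r3);
        }
    }
    return 0;
}
```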
@@ -2834,6 +2896,36 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
     return false;
 }
 
+inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return true;
+        case GGML_TYPE_Q4_K:
+            return !g_ggml_sycl_prioritize_dmmv;
+        default:
+            return false;
+    }
+}
+
+inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return true;
+        default:
+            return false;
+    }
+}
+
+inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_K:
+            return true;
+        default:
+            return false;
+    }
+}
+
 static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
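Note: the three predicates above form a small capability table: Q4_0 can be reordered for every algorithm, while Q4_K qualifies for MMVQ always and for the generic SYCL gemm only when `GGML_SYCL_PRIORITIZE_DMMV` is unset. The same rules condensed into one hypothetical helper (not part of the diff):

```cpp
#include <cstdio>

enum class algo { DMMV, MMVQ, MUL_MAT_SYCL };

static int g_prioritize_dmmv = 0;  // stand-in for g_ggml_sycl_prioritize_dmmv

static bool supports_reorder(algo a, bool q4_0, bool q4_k) {
    if (q4_0) {
        return true;                                  // Q4_0: all three algorithms
    }
    if (q4_k) {
        if (a == algo::MMVQ)         { return true; } // Q4_K: always for MMVQ
        if (a == algo::MUL_MAT_SYCL) { return !g_prioritize_dmmv; }
    }
    return false;                                     // DMMV reorders Q4_0 only
}

int main() {
    std::printf("Q4_K + MUL_MAT_SYCL: %d\n",
                supports_reorder(algo::MUL_MAT_SYCL, false, true));
    return 0;
}
```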
@@ -2853,16 +2945,16 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     }
 }
 
-static void reorder_qw(char *data_device, const int ncols, const int nrows,
-                       size_t size, size_t offset, dpct::queue_ptr stream) {
-    auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
+static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
+                            dpct::queue_ptr stream) {
+    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
     SYCL_CHECK(
         CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
             .wait()));
     GGML_ASSERT((size % sizeof(block_q4_0) == 0));
     GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
     int offset_blks = offset / sizeof(block_q4_0);
-    auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;;
+    auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
    auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
 
     stream->parallel_for(
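Note: `reorder_qw_q4_0` (its kernel body continues as context in the next hunk) converts the weight tensor from an array of blocks to a struct-of-arrays layout: all 4-bit `qs` bytes first, then all fp16 scales, which gives the mul-mat kernels coalesced loads. A host-side sketch of the layout change; `block_q4_0_ish` is a stand-in for ggml's `block_q4_0`, and this is deliberately not the SYCL kernel:

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical stand-in for ggml's block_q4_0 (d is really a ggml fp16).
struct block_q4_0_ish {
    uint16_t d;       // per-block scale
    uint8_t  qs[16];  // 32 4-bit quants, two per byte (QK4_0 = 32)
};

// Array-of-blocks -> struct-of-arrays: all qs bytes first, then all scales,
// mirroring the destination layout written by reorder_qw_q4_0 on-device.
static void reorder_q4_0_host(const block_q4_0_ish * in, uint8_t * out, int nblocks) {
    uint8_t  * qs = out;
    uint16_t * d  = reinterpret_cast<uint16_t *>(out + nblocks * 16);
    for (int ib = 0; ib < nblocks; ++ib) {
        std::memcpy(qs + ib * 16, in[ib].qs, 16);
        d[ib] = in[ib].d;
    }
}

int main() {
    std::vector<block_q4_0_ish> blocks(8, block_q4_0_ish{});
    std::vector<uint8_t> out(8 * (16 + sizeof(uint16_t)));
    reorder_q4_0_host(blocks.data(), out.data(), 8);
    return 0;
}
```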
@@ -2876,48 +2968,119 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows,
             *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
         }
         *(d_ptr + ib) = x[ib].d;
-    });
+    }).wait_and_throw();
+
+    sycl::free(tmp_buf, *stream);
+}
+
+static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
+    GGML_ASSERT(size % sizeof(block_q4_K) == 0);
+    GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
+
+    const int nblocks = size / sizeof(block_q4_K);
+
+    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
+
+    auto * qs_ptr = data_device;
+    auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
+    auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
+
+    stream->parallel_for(nblocks, [=](auto i) {
+        const block_q4_K * x = (const block_q4_K *) tmp_buf;
+        const int ib = i;
+
+        for (int j = 0; j < QK_K / 2; ++j) {
+            qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
+        }
+
+        for (int j = 0; j < K_SCALE_SIZE; ++j) {
+            scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
+        }
+
+        dm_ptr[ib] = x[ib].dm;
+    }).wait_and_throw();
 
     sycl::free(tmp_buf, *stream);
 }
 
 static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
-    char*data_device = (char*)src0->data;
+    uint8_t * data_device = (uint8_t *) src0->data;
     size_t ncols = src0->ne[0];
     size_t nrows = src0->ne[1];
     size_t size = ggml_nbytes(src0);
 
-    reorder_qw(data_device, ncols, nrows, size, 0, stream);
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            reorder_qw_q4_k(data_device, size, 0, stream);
+            break;
+        default:
+            GGML_ABORT("reorder_qw() called with unsupported type");
+            break;
+    }
 }
 
-/*
-* This function could be called when the OP (mul_mat) function support reorder optimizition.
-*/
-static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1,
-                            ggml_tensor * dst) {
-    if (!g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
-        ctx->opt_feature.reorder &&      //allow this device due to good perf, skip the devices with bad perf.
-        dst->op == GGML_OP_MUL_MAT &&    //limit to some supported cases of Q4_0, to do for more cases.
-        src0->type == GGML_TYPE_Q4_0 &&
-        src1->ne[2]==1 && src1->ne[3]==1) {
+static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
+    return !g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
+           ctx.opt_feature.reorder &&       //allow this device due to good perf, skip the devices with bad perf.
+           dst->op == GGML_OP_MUL_MAT &&    //limit to some supported cases of Q4_0, to do for more cases.
+           dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1;
+}
 
-        ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra;
-        if (!extra) return; //only happen in CI/UT permute case.
+static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */,
+                            ggml_tensor * dst, mul_mat_algo mm_algorithm) {
+    if (!should_reorder_tensor(*ctx, dst)) {
+        return;
+    }
 
-        if (extra->optimized_feature.reorder) return; //skip the tensor which is handled for reorder.
+    ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
+    if (!extra || extra->optimized_feature.reorder) {
+        return; // Skip permutations and already reordered tensors
+    }
 
-        reorder_qw(src0, ctx->stream());
-        extra->optimized_feature.reorder = true; //used to decode/dequan in next steps.
+    switch (mm_algorithm) {
+        case mul_mat_algo::DMMV:
+            if (!ggml_sycl_supports_reorder_dmmv(src0->type)) {
+                return;
+            }
+            break;
+        case mul_mat_algo::MMVQ:
+            if (!ggml_sycl_supports_reorder_mmvq(src0->type)) {
+                return;
+            }
+            break;
+        case mul_mat_algo::MUL_MAT_SYCL:
+            if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) {
+                return;
+            }
+            break;
     }
+
+    reorder_qw(src0, ctx->stream());
+    extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering
 }
 
-static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 
+static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+}
+
+static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+}
+
+static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
     int64_t min_compute_capability = INT_MAX;
 
     if (split) {
-        ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
+        ggml_backend_sycl_split_buffer_type_context * buft_ctx =
+            (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
         auto & tensor_split = buft_ctx->tensor_split;
         for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
             // skip devices that are not going to do any work:
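Note: the new Q4_K reorder lays out three regions back to back: quants, packed scales, then the `sycl::half2 {d, dmin}` pairs. A tiny standalone program computing the region offsets for an assumed block count; the `QK_K` and `K_SCALE_SIZE` values match ggml's:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    // ggml constants: QK_K = 256 quants per block, K_SCALE_SIZE = 12 scale bytes.
    const std::size_t QK_K = 256, K_SCALE_SIZE = 12;
    const std::size_t nblocks    = 1024;                // assumed block count
    const std::size_t qs_off     = 0;                   // 4-bit quants, QK_K/2 bytes per block
    const std::size_t scales_off = nblocks * QK_K / 2;  // packed scales follow all quants
    const std::size_t dm_off     = scales_off + nblocks * K_SCALE_SIZE; // half2 {d, dmin}
    std::printf("qs @ %zu, scales @ %zu, dm @ %zu\n", qs_off, scales_off, dm_off);
    return 0;
}
```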
@@ -2930,17 +3093,13 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
             }
         }
     } else {
-        min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
+        min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
     }
 
     // check data types and tensor shapes for custom matrix multiplication kernels:
-    bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+    bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);
 
-    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+    bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);
 
     bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
@@ -2952,9 +3111,15 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
 #endif // SYCL_USE_XMX
 
+
     // mmvq path is faster in the CUDA backend.
-    if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
+    if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
+        // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
+        // is enabled takes precedence over DMMV, the current if-else implementation
+        // requires disabling DMMV if both conditions are met
+        || (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) {
         use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
+    }
 
     if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // TODO: Refactor and cleanup of mul mat dispatching.
@@ -2966,24 +3131,30 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
             // The kernel from the if path is faster for that specific case, but does not support all mul mats.
             ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
         }
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
     } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // KQ + KQV multi-batch
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
-        opt_for_reorder(&ctx, src0, src1, dst); //the OP function in this branch support reorder.
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
-        // save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream());
+        constexpr bool convert_src1_to_q8_1 = false;
+        opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
     } else if (use_mul_mat_vec_q) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
+        constexpr bool convert_src1_to_q8_1 = true;
+        opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
     } else if (use_mul_mat_q) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
+        constexpr bool convert_src1_to_q8_1 = true;
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
     } else {
-        opt_for_reorder(&ctx, src0, src1, dst); //the OP function in this branch support reorder.
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
+        constexpr bool convert_src1_to_q8_1 = false;
+        // MUL_MAT_SYCL supports reorder
+        opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MUL_MAT_SYCL);
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
     }
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
 
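Note: after this hunk, `ggml_sycl_mul_mat` dispatches in the order: batched F16 paths, then DMMV, MMVQ, MMQ, and finally the generic reorder-capable gemm. A condensed sketch of that ordering; the predicates are stand-ins for the real checks, not a drop-in:

```cpp
#include <cstdio>

enum class mm_path { batched, batched_nc, dmmv, mmvq, mmq, mul_mat_sycl };

// Stand-in predicates; the real checks live in ggml_sycl_mul_mat.
static mm_path choose_path(bool split, bool f16_perm_vec, bool f16_nc_vec,
                           bool f16_batched, bool use_dmmv, bool use_mmvq, bool use_mmq) {
    if (!split && f16_perm_vec) { return mm_path::batched; }     // permuted KQ/KQV case
    if (!split && f16_nc_vec)   { return mm_path::batched_nc; }  // KQV single-batch
    if (!split && f16_batched)  { return mm_path::batched; }     // KQ + KQV multi-batch
    if (use_dmmv)               { return mm_path::dmmv; }        // src1 stays fp32
    if (use_mmvq)               { return mm_path::mmvq; }        // src1 quantized to q8_1
    if (use_mmq)                { return mm_path::mmq; }         // src1 quantized to q8_1
    return mm_path::mul_mat_sycl;                                // generic, reorder-capable
}

int main() {
    std::printf("path = %d\n",
                (int) choose_path(false, false, false, false, false, true, false));
    return 0;
}
```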
@@ -3651,7 +3822,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
         return GGML_STATUS_SUCCESS;
     }
 
-    sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
+    sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
+
     model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
     ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
     model_sycl_graph.end_recording();
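Note: `assume_buffer_outlives_graph` is the application's promise that buffers used by recorded commands stay alive for the graph's lifetime, letting the runtime skip some tracking. A sketch of the overall record-and-replay shape, assuming a DPC++ toolchain with the `sycl_ext_oneapi_graph` extension; the real function finalizes and submits the graph in code outside this hunk:

```cpp
#include <sycl/sycl.hpp>

namespace sycl_ex = sycl::ext::oneapi::experimental;

void record_and_replay(sycl::queue & q) {
    // The property mirrors the one added in the hunk above.
    sycl_ex::command_graph graph(q,
        {sycl_ex::property::graph::assume_buffer_outlives_graph{}});

    graph.begin_recording(q);
    // ... submissions to q are captured into the graph instead of executing ...
    graph.end_recording();

    auto exec = graph.finalize();  // build the executable graph once
    q.ext_oneapi_graph(exec);      // replay it
}

int main() {
    sycl::queue q;
    record_and_replay(q);
    q.wait();
    return 0;
}
```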