@fugood/llama.node 0.3.17 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. package/CMakeLists.txt +3 -1
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +39 -2
  19. package/lib/index.js +132 -1
  20. package/lib/index.ts +203 -3
  21. package/package.json +2 -1
  22. package/src/EmbeddingWorker.cpp +1 -1
  23. package/src/LlamaCompletionWorker.cpp +366 -19
  24. package/src/LlamaCompletionWorker.h +30 -10
  25. package/src/LlamaContext.cpp +213 -5
  26. package/src/LlamaContext.h +12 -0
  27. package/src/common.hpp +15 -0
  28. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
  29. package/src/llama.cpp/.github/workflows/build.yml +41 -762
  30. package/src/llama.cpp/.github/workflows/docker.yml +5 -2
  31. package/src/llama.cpp/.github/workflows/release.yml +716 -0
  32. package/src/llama.cpp/.github/workflows/server.yml +12 -12
  33. package/src/llama.cpp/CMakeLists.txt +5 -17
  34. package/src/llama.cpp/cmake/build-info.cmake +8 -2
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
  36. package/src/llama.cpp/common/CMakeLists.txt +31 -3
  37. package/src/llama.cpp/common/arg.cpp +48 -29
  38. package/src/llama.cpp/common/chat.cpp +128 -106
  39. package/src/llama.cpp/common/chat.h +2 -0
  40. package/src/llama.cpp/common/common.cpp +37 -1
  41. package/src/llama.cpp/common/common.h +18 -9
  42. package/src/llama.cpp/common/llguidance.cpp +1 -0
  43. package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
  44. package/src/llama.cpp/common/minja/minja.hpp +69 -36
  45. package/src/llama.cpp/common/regex-partial.cpp +204 -0
  46. package/src/llama.cpp/common/regex-partial.h +56 -0
  47. package/src/llama.cpp/common/sampling.cpp +57 -50
  48. package/src/llama.cpp/examples/CMakeLists.txt +2 -23
  49. package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
  50. package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
  51. package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
  52. package/src/llama.cpp/examples/training/finetune.cpp +96 -0
  53. package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
  54. package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
  55. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
  56. package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
  57. package/src/llama.cpp/ggml/include/ggml.h +10 -7
  58. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
  60. package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
  61. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
  62. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
  63. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
  64. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
  65. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
  66. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  67. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  68. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  69. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
  70. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
  71. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
  72. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
  73. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
  74. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
  75. package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
  76. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
  77. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
  78. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  79. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
  80. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
  81. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
  82. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
  83. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
  84. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  85. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
  86. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
  87. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  88. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
  89. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
  90. package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
  91. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
  92. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
  93. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
  94. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
  95. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
  96. package/src/llama.cpp/ggml/src/ggml.c +29 -20
  97. package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
  98. package/src/llama.cpp/include/llama.h +52 -11
  99. package/src/llama.cpp/requirements/requirements-all.txt +3 -3
  100. package/src/llama.cpp/scripts/xxd.cmake +1 -1
  101. package/src/llama.cpp/src/CMakeLists.txt +1 -0
  102. package/src/llama.cpp/src/llama-adapter.cpp +6 -0
  103. package/src/llama.cpp/src/llama-arch.cpp +3 -0
  104. package/src/llama.cpp/src/llama-batch.cpp +5 -1
  105. package/src/llama.cpp/src/llama-batch.h +2 -1
  106. package/src/llama.cpp/src/llama-chat.cpp +17 -7
  107. package/src/llama.cpp/src/llama-chat.h +1 -0
  108. package/src/llama.cpp/src/llama-context.cpp +389 -501
  109. package/src/llama.cpp/src/llama-context.h +44 -32
  110. package/src/llama.cpp/src/llama-cparams.h +1 -0
  111. package/src/llama.cpp/src/llama-graph.cpp +20 -38
  112. package/src/llama.cpp/src/llama-graph.h +12 -8
  113. package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
  114. package/src/llama.cpp/src/llama-kv-cache.h +271 -85
  115. package/src/llama.cpp/src/llama-memory.h +11 -1
  116. package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
  117. package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
  118. package/src/llama.cpp/src/llama-model-saver.h +37 -0
  119. package/src/llama.cpp/src/llama-model.cpp +316 -69
  120. package/src/llama.cpp/src/llama-model.h +8 -1
  121. package/src/llama.cpp/src/llama-quant.cpp +15 -13
  122. package/src/llama.cpp/src/llama-sampling.cpp +18 -6
  123. package/src/llama.cpp/src/llama-vocab.cpp +42 -4
  124. package/src/llama.cpp/src/llama-vocab.h +6 -0
  125. package/src/llama.cpp/src/llama.cpp +14 -0
  126. package/src/llama.cpp/tests/CMakeLists.txt +10 -2
  127. package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
  128. package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
  129. package/src/llama.cpp/tests/test-chat.cpp +3 -1
  130. package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
  131. package/src/llama.cpp/tests/test-opt.cpp +33 -21
  132. package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
  133. package/src/llama.cpp/tests/test-sampling.cpp +1 -1
  134. package/src/llama.cpp/tools/CMakeLists.txt +39 -0
  135. package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
  136. package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
  137. package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
  138. package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
  139. package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
  140. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
  141. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
  142. package/src/llama.cpp/tools/mtmd/clip.h +99 -0
  143. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
  144. package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
  145. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
  146. package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
  147. package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
  148. package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
  149. package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
  150. package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
  151. package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
  152. package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
  153. package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
  154. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
  155. package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
  156. package/src/llama.cpp/examples/infill/infill.cpp +0 -590
  157. package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
  158. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
  159. package/src/llama.cpp/examples/llava/clip.h +0 -135
  160. package/src/llama.cpp/examples/llava/llava.cpp +0 -586
  161. package/src/llama.cpp/examples/llava/llava.h +0 -49
  162. package/src/llama.cpp/examples/llava/mtmd.h +0 -168
  163. package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
  164. /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
  165. /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
  166. /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
  167. /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
  168. /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
  169. /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
  170. /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
  171. /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
  172. /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
  173. /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
  174. /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
  175. /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
  176. /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
  177. /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
  178. /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
  179. /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
  180. /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
  181. /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
  182. /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
  183. /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
  184. /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
  185. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
  186. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
  187. /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
  188. /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
  189. /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
  190. /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
  191. /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
  192. /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
  193. /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0
@@ -42,6 +42,7 @@ void ggml_sycl_host_free(void* ptr);
42
42
 
43
43
  extern int g_ggml_sycl_debug;
44
44
  extern int g_ggml_sycl_disable_optimize;
45
+ extern int g_ggml_sycl_prioritize_dmmv;
45
46
 
46
47
  #define GGML_SYCL_DEBUG(...) \
47
48
  do { \
@@ -80,10 +81,6 @@ extern int g_ggml_sycl_disable_optimize;
80
81
  // max batch size to use MMQ kernels when tensor cores are available
81
82
  #define MMQ_MAX_BATCH_SIZE 32
82
83
 
83
- #if defined(_MSC_VER)
84
- #pragma warning(disable : 4244 4267) // possible loss of data
85
- #endif
86
-
87
84
  // dmmv = dequantize_mul_mat_vec
88
85
  #ifndef GGML_SYCL_DMMV_X
89
86
  #define GGML_SYCL_DMMV_X 32
@@ -118,17 +115,12 @@ static void crash() {
118
115
  GGML_ABORT("SYCL error");
119
116
  }
120
117
 
121
- #define SYCL_CHECK(err) \
122
- do { \
123
- auto err_ = (err); \
124
- if (err_ != 0) \
125
- ggml_sycl_error( \
126
- #err, \
127
- __func__, \
128
- __FILE__, \
129
- __LINE__, \
130
- "Meet error in this line code!"); \
131
- } while (0)
118
+ #define SYCL_CHECK(err) \
119
+ do { \
120
+ auto err_ = (err); \
121
+ if (err_ != 0) \
122
+ ggml_sycl_error(#err, __func__, __FILE__, __LINE__, "Exception caught in this line of code."); \
123
+ } while (0)
132
124
 
133
125
  #if DPCT_COMPAT_RT_VERSION >= 11100
134
126
  #define GGML_SYCL_ASSUME(x) __builtin_assume(x)
@@ -183,6 +183,24 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
183
183
  }
184
184
  }
185
185
 
186
+ template <typename dst_t>
187
+ static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
188
+ const int64_t nb = k / QK_K;
189
+ const size_t local_size = 32;
190
+ const size_t global_size = nb * local_size;
191
+
192
+ dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
193
+
194
+ stream->submit([&](sycl::handler & cgh) {
195
+ sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
196
+
197
+ cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
198
+ [=](sycl::nd_item<1> item_ct1) {
199
+ dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
200
+ });
201
+ });
202
+ }
203
+
186
204
  template <typename dst_t>
187
205
  static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
188
206
  dpct::queue_ptr stream) {
@@ -437,41 +455,52 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k
437
455
  }
438
456
 
439
457
  template <typename src_t, typename dst_t>
440
- static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
441
- const sycl::nd_item<3> &item_ct1) {
458
+ static void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
459
+ const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03,
460
+ const sycl::nd_item<3> & item_ct1) {
461
+
442
462
  const int64_t work_group_size = item_ct1.get_local_range(2);
443
- const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
463
+ const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
464
+
465
+ const int64_t i01 = item_ct1.get_group(1);
466
+ const int64_t i02 = item_ct1.get_group(0) % ne02;
467
+ const int64_t i03 = item_ct1.get_group(0) / ne02;
444
468
 
445
469
  // make each work-item deal with more elements since sycl global range can not exceed max int
446
- const src_t * x = (const src_t *) vx;
447
- for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) {
448
- y[i] = x[i];
470
+ const src_t * x = static_cast<const src_t *>(vx);
471
+ const int64_t ix = i03 * s03 + i02 * s02 + i01 * s01;
472
+ const int64_t iy = ((i03 * ne02 + i02) * ne01 + i01) * ne00;
473
+
474
+ #pragma unroll
475
+ for (int64_t i00 = global_id; i00 < ne00; i00 += work_group_size * item_ct1.get_group_range(2)) {
476
+ y[iy + i00] = static_cast<dst_t>(x[ix + i00]);
449
477
  }
450
478
  }
451
479
 
452
480
  template <typename src_t, typename dst_t>
453
- static void convert_unary_sycl(const void *__restrict__ vx,
454
- dst_t *__restrict__ y, const int64_t k,
455
- dpct::queue_ptr stream) {
456
- const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
481
+ static void convert_unary_nc_sycl(const void * __restrict__ vx, dst_t * __restrict__ y,
482
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
483
+ const int64_t s01, const int64_t s02, const int64_t s03, dpct::queue_ptr queue) {
484
+ dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });
485
+
486
+ sycl::range<3> global_size(ne02 * ne03, ne01, ceil_div(ne00, SYCL_DEQUANTIZE_BLOCK_SIZE));
457
487
 
458
488
  // decrease global range when it exceeds the max int
459
- int64_t local_size = downsample_sycl_global_range(num_blocks, SYCL_DEQUANTIZE_BLOCK_SIZE);
460
- sycl::range<3> block_nums(1, 1, num_blocks);
461
- sycl::range<3> local_range(1, 1, local_size);
462
- {
463
- dpct::has_capability_or_fail(stream->get_device(),
464
- {sycl::aspect::fp16});
489
+ // TODO: Downsample logic is separated from the kernel, a rewrite is desirable
490
+ int64_t downsized_workgroup = downsample_sycl_global_range(global_size[0], SYCL_DEQUANTIZE_BLOCK_SIZE);
491
+ sycl::range<3> workgroup_size(1, 1, downsized_workgroup);
465
492
 
466
- stream->parallel_for(
467
- sycl::nd_range<3>(block_nums * local_range, local_range),
468
- [=](sycl::nd_item<3> item_ct1) {
469
- convert_unary<src_t>(vx, y, k, item_ct1);
470
- });
471
- }
493
+ queue->parallel_for(sycl::nd_range<3>(global_size * workgroup_size, workgroup_size), [=](sycl::nd_item<3> item_ct1) {
494
+ convert_unary_nc<src_t>(vx, y, ne00, ne01, ne02, s01, s02, s03, item_ct1);
495
+ });
472
496
  }
473
497
 
474
- to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst) {
498
+ template <typename src_t, typename dst_t>
499
+ static void convert_unary_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr queue) {
500
+ convert_unary_nc_sycl<src_t>(vx, y, k, 1, 1, 1, k, k, k, queue);
501
+ }
502
+
503
+ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
475
504
  switch (type) {
476
505
  case GGML_TYPE_Q4_0:
477
506
  if (dst->src[0]->extra &&
@@ -493,7 +522,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst) {
493
522
  case GGML_TYPE_Q3_K:
494
523
  return dequantize_row_q3_K_sycl;
495
524
  case GGML_TYPE_Q4_K:
496
- return dequantize_row_q4_K_sycl;
525
+ if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
526
+ return dequantize_row_q4_K_sycl_reorder;
527
+ } else {
528
+ return dequantize_row_q4_K_sycl;
529
+ }
497
530
  case GGML_TYPE_Q5_K:
498
531
  return dequantize_row_q5_K_sycl;
499
532
  case GGML_TYPE_Q6_K:
@@ -545,7 +578,12 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
545
578
  case GGML_TYPE_Q3_K:
546
579
  return dequantize_row_q3_K_sycl;
547
580
  case GGML_TYPE_Q4_K:
548
- return dequantize_row_q4_K_sycl;
581
+ if (dst->src[0]->extra &&
582
+ ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
583
+ return dequantize_row_q4_K_sycl_reorder;
584
+ } else {
585
+ return dequantize_row_q4_K_sycl;
586
+ }
549
587
  case GGML_TYPE_Q5_K:
550
588
  return dequantize_row_q5_K_sycl;
551
589
  case GGML_TYPE_Q6_K:
@@ -574,3 +612,12 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
574
612
  return nullptr;
575
613
  }
576
614
  }
615
+
616
+ to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
617
+ switch (type) {
618
+ case GGML_TYPE_F32:
619
+ return convert_unary_nc_sycl<float>;
620
+ default:
621
+ return nullptr;
622
+ }
623
+ }
@@ -1,6 +1,6 @@
1
1
  //
2
2
  // MIT license
3
- // Copyright (C) 2024 Intel Corporation
3
+ // Copyright (C) 2025 Intel Corporation
4
4
  // SPDX-License-Identifier: MIT
5
5
  //
6
6
 
@@ -16,12 +16,19 @@
16
16
  #include "common.hpp"
17
17
 
18
18
  template <typename T>
19
- using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y,
20
- int64_t k, dpct::queue_ptr stream);
21
- typedef to_t_sycl_t<float> to_fp32_sycl_t;
19
+ using to_t_sycl_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, dpct::queue_ptr stream);
20
+ typedef to_t_sycl_t<float> to_fp32_sycl_t;
22
21
  typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;
23
22
 
24
- to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst);
25
- to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst);
23
+ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst);
24
+ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor * dst);
26
25
 
27
- #endif // GGML_SYCL_CONVERT_HPP
26
+ // Nc = Non-contiguous
27
+ template <typename T>
28
+ using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
29
+ int64_t s01, int64_t s02, int64_t s03, dpct::queue_ptr queue);
30
+
31
+ typedef to_t_nc_sycl_t<sycl::half> to_fp16_nc_sycl_t;
32
+ to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type);
33
+
34
+ #endif // GGML_SYCL_CONVERT_HPP
@@ -357,6 +357,28 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8
357
357
  }
358
358
  #endif
359
359
 
360
+ template <typename dst_t>
361
+ inline void dequantize_q4_K_common(dst_t * __restrict__ y, const uint8_t * __restrict__ qs_ptr, const float dall,
362
+ const float dmin, uint8_t * __restrict__ scales_local, int il, int ir) {
363
+ const int is = 2 * il;
364
+ constexpr int n = 4;
365
+
366
+ uint8_t sc, m;
367
+ get_scale_min_k4(is + 0, scales_local, sc, m);
368
+ const float d1 = dall * sc;
369
+ const float m1 = dmin * m;
370
+
371
+ get_scale_min_k4(is + 1, scales_local, sc, m);
372
+ const float d2 = dall * sc;
373
+ const float m2 = dmin * m;
374
+
375
+ sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(qs_ptr + 32 * il + n * ir);
376
+ for (int l = 0; l < n; ++l) {
377
+ y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
378
+ y[l + 32] = d2 * (q_vec[l] >> 4) - m2;
379
+ }
380
+ }
381
+
360
382
  template<typename dst_t>
361
383
  static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
362
384
  uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
@@ -365,36 +387,22 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
365
387
  const int64_t i = item_ct1.get_group(2);
366
388
 
367
389
  #if QK_K == 256
368
- // assume 32 threads
369
390
  const int64_t tid = item_ct1.get_local_id(2);
370
- const int64_t il = tid/8;
371
- const int64_t ir = tid%8;
372
- const int64_t is = 2*il;
373
- const int64_t n = 4;
391
+ const int64_t il = tid / 8;
392
+ const int64_t ir = tid % 8;
374
393
 
375
- dst_t * y = yy + i*QK_K + 64*il + n*ir;
394
+ dst_t * y = yy + i * QK_K + 64 * il + 4 * ir;
376
395
 
377
396
  const sycl::half2 dm = x[i].dm;
378
397
  const float dall = dm[0];
379
398
  const float dmin = dm[1];
380
399
 
381
- if (tid < 12)
400
+ if (tid < 12) {
382
401
  scales_local[tid] = x[i].scales[tid];
383
- item_ct1.barrier(sycl::access::fence_space::local_space);
384
-
385
- uint8_t sc, m;
386
- get_scale_min_k4(is + 0, scales_local, sc, m);
387
- const float d1 = dall * sc;
388
- const float m1 = dmin * m;
389
- get_scale_min_k4(is + 1, scales_local, sc, m);
390
- const float d2 = dall * sc;
391
- const float m2 = dmin * m;
392
-
393
- sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(x[i].qs + 32*il + n*ir);
394
- for (int l = 0; l < n; ++l) {
395
- y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
396
- y[l +32] = d2 * (q_vec[l] >> 4) - m2;
397
402
  }
403
+
404
+ item_ct1.barrier(sycl::access::fence_space::local_space);
405
+ dequantize_q4_K_common(y, x[i].qs, dall, dmin, scales_local, il, ir);
398
406
  #else
399
407
  const int64_t tid = item_ct1.get_local_id(2);
400
408
  const uint8_t * q = x[i].qs;
@@ -406,6 +414,36 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
406
414
  #endif
407
415
  }
408
416
 
417
+ template <typename dst_t>
418
+ static void dequantize_block_q4_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, uint8_t * scales_local,
419
+ const sycl::nd_item<1> & item_ct1, int64_t nb) {
420
+ const int64_t i = item_ct1.get_group(0); // block index
421
+ const int64_t tid = item_ct1.get_local_id(0); // thread index within block
422
+ const int64_t il = tid / 8;
423
+ const int64_t ir = tid % 8;
424
+
425
+ dst_t * y = yy + i * QK_K + 64 * il + 4 * ir;
426
+
427
+ const uint8_t * base = static_cast<const uint8_t *>(vx);
428
+ const size_t qs_offset = i * (QK_K / 2);
429
+ const size_t scales_offset = nb * (QK_K / 2) + i * K_SCALE_SIZE;
430
+ const size_t dm_offset = nb * (QK_K / 2) + nb * K_SCALE_SIZE + i * sizeof(ggml_half2);
431
+
432
+ const uint8_t * qs_ptr = base + qs_offset;
433
+ const uint8_t * scales_ptr = base + scales_offset;
434
+ ggml_half2 dm_values = *reinterpret_cast<const ggml_half2 *>(base + dm_offset);
435
+
436
+ const float dall = dm_values.x();
437
+ const float dmin = dm_values.y();
438
+
439
+ if (tid < 12) {
440
+ scales_local[tid] = scales_ptr[tid];
441
+ }
442
+
443
+ item_ct1.barrier(sycl::access::fence_space::local_space);
444
+ dequantize_q4_K_common(y, qs_ptr, dall, dmin, scales_local, il, ir);
445
+ }
446
+
409
447
  template<typename dst_t>
410
448
  static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
411
449
  const sycl::nd_item<3> &item_ct1) {
@@ -1129,7 +1129,13 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
1129
1129
  dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
1130
1130
  break;
1131
1131
  case GGML_TYPE_Q4_K:
1132
- dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
1132
+ if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
1133
+ ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
1134
+ // reorder is currently not supported for dmmv
1135
+ GGML_ABORT("Unimplemented dequantize case case for q4_k reorder");
1136
+ } else {
1137
+ dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
1138
+ }
1133
1139
  break;
1134
1140
  case GGML_TYPE_Q5_K:
1135
1141
  dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
@@ -655,7 +655,6 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
655
655
  }
656
656
  default:
657
657
  GGML_ABORT("GGML tensor type not supported!\n");
658
- break;
659
658
  }
660
659
  }
661
660
 
@@ -688,7 +687,6 @@ inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
688
687
  }
689
688
  default:
690
689
  GGML_ABORT("GGML tensor type not supported!\n");
691
- break;
692
690
  }
693
691
  }
694
692
 
@@ -722,7 +720,6 @@ inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
722
720
  }
723
721
  default:
724
722
  GGML_ABORT("GGML tensor type not supported!\n");
725
- break;
726
723
  }
727
724
  }
728
725
 
@@ -754,7 +751,6 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
754
751
  }
755
752
  default:
756
753
  GGML_ABORT("GGML tensor type not supported!\n");
757
- break;
758
754
  }
759
755
  }
760
756
 
@@ -786,7 +782,6 @@ inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
786
782
  }
787
783
  default:
788
784
  GGML_ABORT("GGML tensor type not supported!\n");
789
- break;
790
785
  }
791
786
  }
792
787
 
@@ -818,7 +813,6 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
818
813
  }
819
814
  default:
820
815
  GGML_ABORT("GGML tensor type not supported!\n");
821
- break;
822
816
  }
823
817
  }
824
818
 
@@ -850,7 +844,6 @@ inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst
850
844
  }
851
845
  default:
852
846
  GGML_ABORT("GGML tensor type not supported!\n");
853
- break;
854
847
  }
855
848
  }
856
849
 
@@ -883,7 +876,6 @@ inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
883
876
  }
884
877
  default:
885
878
  GGML_ABORT("GGML tensor type not supported!\n");
886
- break;
887
879
  }
888
880
  }
889
881
 
@@ -917,7 +909,6 @@ inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tenso
917
909
  }
918
910
  default:
919
911
  GGML_ABORT("GGML tensor type not supported!\n");
920
- break;
921
912
  }
922
913
  }
923
914
 
@@ -949,7 +940,6 @@ inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor
949
940
  }
950
941
  default:
951
942
  GGML_ABORT("GGML tensor type not supported!\n");
952
- break;
953
943
  }
954
944
  }
955
945
 
@@ -981,7 +971,6 @@ inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
981
971
  }
982
972
  default:
983
973
  GGML_ABORT("GGML tensor type not supported!\n");
984
- break;
985
974
  }
986
975
  }
987
976
 
@@ -1013,7 +1002,6 @@ inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
1013
1002
  }
1014
1003
  default:
1015
1004
  GGML_ABORT("GGML tensor type not supported!\n");
1016
- break;
1017
1005
  }
1018
1006
  }
1019
1007
 
@@ -1045,7 +1033,6 @@ inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor *
1045
1033
  }
1046
1034
  default:
1047
1035
  GGML_ABORT("GGML tensor type not supported!\n");
1048
- break;
1049
1036
  }
1050
1037
  }
1051
1038
 
@@ -1078,7 +1065,6 @@ inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst
1078
1065
  }
1079
1066
  default:
1080
1067
  GGML_ABORT("GGML tensor type not supported!\n");
1081
- break;
1082
1068
  }
1083
1069
  }
1084
1070
 
@@ -1110,7 +1096,6 @@ inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
1110
1096
  }
1111
1097
  default:
1112
1098
  GGML_ABORT("GGML tensor type not supported!\n");
1113
- break;
1114
1099
  }
1115
1100
  }
1116
1101
 
@@ -1142,7 +1127,6 @@ inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
1142
1127
  }
1143
1128
  default:
1144
1129
  GGML_ABORT("GGML tensor type not supported!\n");
1145
- break;
1146
1130
  }
1147
1131
  }
1148
1132
 
@@ -1174,7 +1158,6 @@ inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst
1174
1158
  }
1175
1159
  default:
1176
1160
  GGML_ABORT("GGML tensor type not supported!\n");
1177
- break;
1178
1161
  }
1179
1162
  }
1180
1163
 
@@ -1206,7 +1189,6 @@ inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
1206
1189
  }
1207
1190
  default:
1208
1191
  GGML_ABORT("GGML tensor type not supported!\n");
1209
- break;
1210
1192
  }
1211
1193
  }
1212
1194
 
@@ -1241,7 +1223,6 @@ inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor
1241
1223
  }
1242
1224
  default:
1243
1225
  GGML_ABORT("GGML tensor type not supported!\n");
1244
- break;
1245
1226
  }
1246
1227
  }
1247
1228
 
@@ -1273,7 +1254,6 @@ inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
1273
1254
  }
1274
1255
  default:
1275
1256
  GGML_ABORT("GGML tensor type not supported!\n");
1276
- break;
1277
1257
  }
1278
1258
  }
1279
1259
 
@@ -1315,7 +1295,6 @@ inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor *
1315
1295
  }
1316
1296
  default:
1317
1297
  GGML_ABORT("GGML tensor type not supported!\n");
1318
- break;
1319
1298
  }
1320
1299
  }
1321
1300
 
@@ -1350,7 +1329,6 @@ inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
1350
1329
  }
1351
1330
  default:
1352
1331
  GGML_ABORT("GGML tensor type not supported!\n");
1353
- break;
1354
1332
  }
1355
1333
  }
1356
1334
 
@@ -1388,7 +1366,6 @@ inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * ds
1388
1366
  }
1389
1367
  default:
1390
1368
  GGML_ABORT("GGML tensor type not supported!\n");
1391
- break;
1392
1369
  }
1393
1370
  }
1394
1371
 
@@ -32,16 +32,36 @@ public:
32
32
  else static_assert(0);
33
33
  }
34
34
 
35
- static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
36
- const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
35
+ // matrix A has m rows, k columns
36
+ // matrix B has k rows, n columns
37
+ // nra - number of elements to skip when moving into next row in A
38
+ // nrb - number of elements to skip when moving into next row in B
39
+ // nca - number of elements to skip when moving into next column in A
40
+ // ncb - number of elements to skip when moving into next column in B
41
+ // stride_a - number of elements to skip when moving to next A matrix
42
+ // stride_b - number of elements to skip when moving to next B matrix
43
+ // batches_a - number of A matrices
44
+ // batches_b - number of B matrices
45
+ static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
46
+ const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
47
+ const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
48
+ void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
49
+
37
50
  auto stream = ctx.stream_dnnl(q);
38
51
  auto eng = ctx.engine_dnnl(q);
39
- dnnl::memory::dims a_dims = { m, k };
40
- dnnl::memory::dims b_dims = { k, n };
41
- dnnl::memory::dims c_dims = { m, n };
42
- const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
43
- const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
44
- const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
52
+
53
+ // { # strides, # rows, # columns }
54
+ dnnl::memory::dims a_dims = { batches_a, m, k };
55
+ dnnl::memory::dims b_dims = { batches_b, k, n };
56
+ dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
57
+
58
+ // { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column }
59
+ dnnl::memory::dims a_strides = { stride_a, nra, nca };
60
+ dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
61
+
62
+ const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
63
+ const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
64
+ const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
45
65
 
46
66
  dnnl::primitive_attr primitive_attr;
47
67
  primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
@@ -63,6 +83,15 @@ public:
63
83
 
64
84
  matmul_prim.execute(stream, matmul_args);
65
85
  }
86
+
87
+ // Single-batch convenience wrapper over gemm(). Inputs A and B are column major with leading dimension k (both have k rows)
88
+ // matrix A has m columns, matrix B has n columns; A is read transposed (row stride k, column stride 1), B as-is (row stride 1, column stride k)
89
+ // output: C = A transposed * B, an m x n matrix stored densely row-major (gemm describes C with tag::abc) - equivalently, C transposed in column-major order
90
+ static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
91
+ const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
92
+
93
+ gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
94
+ }
66
95
  };
67
96
 
68
97
  #endif